| import os |
| import threading |
| import warnings |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from enum import IntEnum |
| from math import ceil |
|
|
| from ..converters.convert_to_ternary8 import pack_ternary, unpack_ternary |
|
|
# Optional GPU kernel backends. Each flag is True only when every import in
# its ``try`` succeeded; an ImportError leaves the flag False.
try:
    import tilelang
    import tilelang.language as T
except ImportError:
    _HAS_TILELANG = False
else:
    _HAS_TILELANG = True


try:
    import triton
    import triton.language as tl
except ImportError:
    _HAS_TRITON = False
else:
    _HAS_TRITON = True
|
|
|
|
def _backend_preference() -> str:
    """Return the kernel backend requested through ``ARB_TERNARY_BACKEND``.

    The variable is whitespace-stripped and lower-cased; when unset it means
    ``"auto"``. Any value other than ``auto``, ``tilelang``, ``triton`` or
    ``torch`` emits a ``RuntimeWarning`` and ``"auto"`` is returned instead.
    """
    requested = os.environ.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if requested in ("auto", "tilelang", "triton", "torch"):
        return requested
    warnings.warn(
        f"Unknown ARB_TERNARY_BACKEND={requested!r}; falling back to auto.",
        RuntimeWarning,
        stacklevel=2,
    )
    return "auto"
|
|
|
|
def _rmsnorm_triton_max_dim() -> int:
    """Return the integer limit configured by ``ARB_RMSNORM_TRITON_MAX_DIM``.

    Defaults to ``4096``. Negative values are clamped to ``0``. A value that
    does not parse as an ``int`` emits a ``RuntimeWarning`` and the default is
    returned. (Presumably the largest dimension routed to a Triton RMSNorm
    kernel; the consumer is not visible here.)
    """
    text = os.environ.get("ARB_RMSNORM_TRITON_MAX_DIM", "4096").strip()
    try:
        parsed = int(text)
    except ValueError:
        warnings.warn(
            f"Invalid ARB_RMSNORM_TRITON_MAX_DIM={text!r}; using 4096.",
            RuntimeWarning,
            stacklevel=2,
        )
        return 4096
    return parsed if parsed > 0 else 0
|
|
|
|
def _bigint_corr_strength() -> float:
    """Return the exponent-correction strength from ``ARB_BIGINT_CORR_STRENGTH``.

    Defaults to ``4.0``. A value that does not parse as a float, or parses to a
    non-finite number (``nan``, ``inf``, ``-inf``), emits a ``RuntimeWarning``
    and the default is returned instead. The strength is multiplied into the
    exponent of every dequantized weight, so a non-finite value would turn the
    weights into NaN/inf rather than failing loudly.
    """
    import math

    raw = os.environ.get("ARB_BIGINT_CORR_STRENGTH", "4.0").strip()
    try:
        value = float(raw)
    except ValueError:
        value = None
    if value is None or not math.isfinite(value):
        warnings.warn(
            f"Invalid ARB_BIGINT_CORR_STRENGTH={raw!r}; using 4.0.",
            RuntimeWarning,
            stacklevel=2,
        )
        return 4.0
    return value
|
|
|
|
class _ComponentContext:
    """Per-thread holder for the active ``(component name, weight)`` pair.

    State lives on a single class-level ``threading.local``, so each thread
    sees its own value. The API is exposed only through classmethods.
    """

    _local = threading.local()

    @classmethod
    def get(cls):
        """Return the active ``(name, weight)`` tuple, or ``(None, 1.0)`` if none."""
        current = getattr(cls._local, "current", None)
        return (None, 1.0) if current is None else current

    @classmethod
    def set(cls, name, weight=1.0):
        """Activate ``(name, weight)`` for this thread; ``name=None`` clears it."""
        cls._local.current = None if name is None else (name, weight)

    @classmethod
    def clear(cls):
        """Forget this thread's active context so ``get`` yields ``(None, 1.0)``."""
        cls._local.current = None
|
|
|
|
# Module-level alias through which the autograd code reads the active
# (component name, weight) context.
_COMPONENT_CONTEXT = _ComponentContext
|
|
|
|
def _tilelang_training_enabled() -> bool:
    """Report whether ``ARB_TILELANG_TRAINING`` is set to a truthy spelling.

    Accepted spellings (case-insensitive, surrounding whitespace ignored) are
    ``1``, ``true`` and ``yes``; anything else, including unset, is False.
    """
    flag = os.environ.get("ARB_TILELANG_TRAINING", "0")
    normalized = flag.strip().lower()
    return normalized in ("1", "true", "yes")
|
|
|
|
| if _HAS_TILELANG: |
|
|
    # JIT decorator shared by the TileLang kernel factories below; warp
    # specialization is disabled via the "tl.disable_warp_specialized" pass config.
    tilelang_jit = tilelang.jit(pass_configs={"tl.disable_warp_specialized": True})
|
|
| def _ternary_fwd_kernel( |
| M: int, N: int, K: int, group_size: int = 12, |
| corr_strength: float = 4.0, |
| block_M: int = 64, block_N: int = 64, block_K: int = 32, |
| threads: int = 128, num_stages: int = 2, |
| ): |
| gpr = (K + group_size - 1) // group_size |
| cs = corr_strength |
|
|
| @T.prim_func |
| def kernel( |
| x: T.Tensor((M, K), "float16"), |
| T_packed: T.Tensor((N * K + 4) // 5, "uint8"), |
| E: T.Tensor((N * gpr), "int8"), |
| corr_accum: T.Tensor((N * gpr), "int64"), |
| step_counter: T.Tensor((1,), "int64"), |
| output: T.Tensor((M, N), "float32"), |
| ): |
| steps = T.cast(step_counter[0], "int32") |
| with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): |
| x_shared = T.alloc_shared((block_M, block_K), dtype="float16") |
| dq_shared = T.alloc_shared((block_N, block_K), dtype="float16") |
| acc = T.alloc_fragment((block_M, block_N), dtype="float32") |
| T.use_swizzle(10) |
| T.clear(acc) |
| for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): |
| T.copy(x[bx * block_M, k * block_K], x_shared) |
| for i, j in T.Parallel(block_N, block_K): |
| i_glob = by * block_N + i |
| j_glob = k * block_K + j |
| if i_glob < N and j_glob < K: |
| lin_idx = i_glob * K + j_glob |
| pack_idx = lin_idx // 5 |
| trit_pos = lin_idx % 5 |
| packed_val = T.cast(T_packed[pack_idx], "int32") |
| trit = T.if_then_else( |
| trit_pos == 0, packed_val % 3, |
| T.if_then_else(trit_pos == 1, (packed_val // 3) % 3, |
| T.if_then_else(trit_pos == 2, (packed_val // 9) % 3, |
| T.if_then_else(trit_pos == 3, (packed_val // 27) % 3, |
| (packed_val // 81) % 3)))) |
| sign_val = T.cast(trit, "int32") - 1 |
| exp_idx = i_glob * gpr + j_glob // group_size |
| exp_val = T.cast(E[exp_idx], "int32") |
| ca = T.cast(corr_accum[exp_idx], "int32") |
| den = T.max(steps * group_size, 1) |
| mc = T.cast(ca, "float32") / T.cast(den, "float32") |
| e_adj = T.cast(exp_val, "float32") + mc * cs |
| ecl = T.min(T.max(e_adj, -14.0), 15.0) |
| dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16") |
| T.gemm(x_shared, dq_shared, acc, transpose_B=True) |
| T.copy(acc, output[bx * block_M, by * block_N]) |
| return tilelang_jit(kernel) |
|
|
| def _ternary_grad_x_kernel( |
| M: int, N: int, K: int, group_size: int = 12, |
| corr_strength: float = 4.0, |
| block_M: int = 64, block_N: int = 64, block_K: int = 32, |
| threads: int = 128, num_stages: int = 2, |
| ): |
| gpr = (K + group_size - 1) // group_size |
| cs = corr_strength |
|
|
| @T.prim_func |
| def kernel( |
| grad_y: T.Tensor((M, N), "float16"), |
| T_packed: T.Tensor((N * K + 4) // 5, "uint8"), |
| E: T.Tensor((N * gpr), "int8"), |
| corr_accum: T.Tensor((N * gpr), "int64"), |
| step_counter: T.Tensor((1,), "int64"), |
| output: T.Tensor((M, K), "float32"), |
| ): |
| steps = T.cast(step_counter[0], "int32") |
| with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=threads) as (bx, by): |
| gy_shared = T.alloc_shared((block_M, block_N), dtype="float16") |
| dq_shared = T.alloc_shared((block_N, block_K), dtype="float16") |
| acc = T.alloc_fragment((block_M, block_K), dtype="float32") |
| T.use_swizzle(10) |
| T.clear(acc) |
| for n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages): |
| T.copy(grad_y[bx * block_M, n * block_N], gy_shared) |
| for i, j in T.Parallel(block_N, block_K): |
| i_glob = n * block_N + i |
| j_glob = by * block_K + j |
| if i_glob < N and j_glob < K: |
| lin_idx = i_glob * K + j_glob |
| pack_idx = lin_idx // 5 |
| trit_pos = lin_idx % 5 |
| packed_val = T.cast(T_packed[pack_idx], "int32") |
| trit = T.if_then_else( |
| trit_pos == 0, packed_val % 3, |
| T.if_then_else(trit_pos == 1, (packed_val // 3) % 3, |
| T.if_then_else(trit_pos == 2, (packed_val // 9) % 3, |
| T.if_then_else(trit_pos == 3, (packed_val // 27) % 3, |
| (packed_val // 81) % 3)))) |
| sign_val = T.cast(trit, "int32") - 1 |
| exp_idx = i_glob * gpr + j_glob // group_size |
| exp_val = T.cast(E[exp_idx], "int32") |
| ca = T.cast(corr_accum[exp_idx], "int32") |
| den = T.max(steps * group_size, 1) |
| mc = T.cast(ca, "float32") / T.cast(den, "float32") |
| e_adj = T.cast(exp_val, "float32") + mc * cs |
| ecl = T.min(T.max(e_adj, -14.0), 15.0) |
| dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16") |
| T.gemm(gy_shared, dq_shared, acc) |
| T.copy(acc, output[bx * block_M, by * block_K]) |
| return tilelang_jit(kernel) |
|
|
    # Compiled-kernel caches filled lazily by _get_kernel: _KERNEL_CACHE_FWD
    # holds forward kernels and _KERNEL_CACHE_GX holds grad-x kernels. These are
    # plain dicts, and nothing in this chunk evicts entries, so one compiled
    # kernel is kept per distinct key.
    _KERNEL_CACHE_FWD = {}
    _KERNEL_CACHE_GX = {}
|
|
| def _get_kernel(M, N, K, group_size, mode, corr_strength=4.0): |
| cs = corr_strength |
| if mode == "fwd": |
| cache = _KERNEL_CACHE_FWD |
| key = (M, N, K, group_size, cs) |
| if key not in cache: |
| cache[key] = _ternary_fwd_kernel(M, N, K, group_size, corr_strength=cs) |
| return cache[key] |
| elif mode == "grad_x": |
| cache = _KERNEL_CACHE_GX |
| key = (M, N, K, group_size) |
| if key not in cache: |
| cache[key] = _ternary_grad_x_kernel(M, N, K, group_size) |
| return cache[key] |
| raise ValueError(f"Unknown TileLang kernel mode: {mode}") |
|
|
|
|
| def _get_grad_kernels(M, N, K, group_size): |
| return _get_kernel(M, N, K, group_size, "grad_x") |
|
|
|
|
    class _TernaryLinearFn(torch.autograd.Function):
        """Autograd function for a ternary linear layer backed by TileLang kernels.

        ``forward`` computes ``x @ W.T`` with a caller-supplied compiled kernel.
        ``W`` is dequantized from ``module``'s packed ternary signs and
        per-group exponents. ``backward`` returns only the input gradient; the
        weight-side signal is written onto ``module`` as a side effect, and the
        ``module`` / ``fwd_kernel`` arguments receive ``None``.
        """

        @staticmethod
        def forward(ctx, x, module, fwd_kernel):
            """Run the ternary matmul.

            Args:
                ctx: Autograd context.
                x: Input whose last dimension is ``K``; flattened to ``(-1, K)``.
                module: Provides ``T_packed``, ``E``, ``_T_shape`` (``(N, K)``) and
                    ``group_size``. ``corr_accum`` / ``step_counter`` are used only
                    when both attributes exist.
                fwd_kernel: Compiled kernel invoked as
                    ``fwd_kernel(x_fp16, T_packed, E, corr_accum, step_counter, out)``.
                    Presumably built by ``_get_kernel(M, N, K, group_size, "fwd")``
                    for this batch size. TODO: confirm with the caller.

            Returns:
                float32 tensor of shape ``x.shape[:-1] + (N,)``.
            """
            ctx.module = module
            T_packed = module.T_packed
            E = module.E
            shape = tuple(module._T_shape.tolist())
            N, K = shape
            x_2d = x.reshape(-1, K).contiguous()
            ctx.group_size = module.group_size
            ctx.shape = shape
            ctx.x_shape = x.shape
            # Remember which component was active in forward; backward keys its
            # hook attributes on this name.
            comp_name, _ = _COMPONENT_CONTEXT.get()
            ctx.comp_name = comp_name
            ctx.x_dtype = x.dtype
            has_corr = hasattr(module, "corr_accum") and hasattr(module, "step_counter")
            ctx.save_for_backward(x_2d, T_packed, E)
            ctx.has_corr = has_corr
            # Host-side snapshot of the step count (.item() forces a sync) so
            # backward dequantizes with the same step denominator as forward.
            ctx.step_snapshot = int(module.step_counter.item()) if has_corr else 0
            with torch.no_grad():
                M = x_2d.shape[0]
                output = torch.empty(M, N, device=x.device, dtype=torch.float32)
                if has_corr:
                    fwd_kernel(x_2d.half(), T_packed, E,
                               module.corr_accum.contiguous(),
                               module.step_counter.contiguous(), output)
                else:
                    # No correction state: pass zeroed buffers so the correction
                    # term vanishes (den clamps to 1, corr is 0).
                    fwd_kernel(x_2d.half(), T_packed, E,
                               torch.zeros(N * ((K + module.group_size - 1) // module.group_size),
                                           dtype=torch.int64, device=x.device),
                               torch.zeros(1, dtype=torch.int64, device=x.device), output)
            return output.reshape(*x.shape[:-1], N)

        @staticmethod
        def backward(ctx, grad_output):
            """Return ``(grad_x, None, None)`` and record weight-update data on the module.

            The weight-side action depends on what is available:

            * Triton available, correction state present and
              ``module._stream_backward_updates`` not False: fold the update into
              ``module.corr_accum`` via ``_triton_accumulate_corr_direct`` (not
              defined in this chunk), advance ``module.step_counter`` by
              ``|corr_step|`` and set ``module._streamed_bigint_backward``.
            * Triton available otherwise: store ``sign(grad^T @ x)`` as
              ``_hook_grad_T_sign_<comp>`` (or ``_hook_grad_T_sign`` without a
              component name).
            * No Triton: store raw ``grad_2d`` / ``x_2d`` as
              ``_hook_grad_2d[_<comp>]`` / ``_hook_x_2d[_<comp>]``.
            """
            x_2d, T_packed, E = ctx.saved_tensors
            group_size = ctx.group_size
            N, K = ctx.shape
            M = x_2d.shape[0]
            grad_2d = grad_output.reshape(-1, N).contiguous()
            if ctx.has_corr:
                # Current corr_accum paired with the step count snapshotted in forward.
                corr_accum = ctx.module.corr_accum.contiguous()
                step_counter = torch.tensor([ctx.step_snapshot], dtype=torch.int64, device=x_2d.device)
            else:
                corr_accum = torch.zeros(N * ((K + group_size - 1) // group_size),
                                         dtype=torch.int64, device=x_2d.device)
                step_counter = torch.zeros(1, dtype=torch.int64, device=x_2d.device)
            # NOTE(review): built with the default corr_strength; confirm it matches
            # the strength fwd_kernel was compiled with, otherwise grad_x is taken
            # against different weights than the forward pass used.
            grad_x_kernel = _get_grad_kernels(M, N, K, group_size)
            with torch.no_grad():
                grad_x = torch.empty(M, K, device=x_2d.device, dtype=torch.float32)
                grad_x_kernel(grad_2d.half(), T_packed, E, corr_accum, step_counter, grad_x)
            comp_name = ctx.comp_name
            if _HAS_TRITON and ctx.has_corr and getattr(ctx.module, "_stream_backward_updates", True):
                # Scale the correction step by the component weight active during
                # backward (1.0 when no component is set); its sign is preserved.
                bwd_name, bwd_weight = _COMPONENT_CONTEXT.get()
                if bwd_name is None:
                    bwd_weight = 1.0
                base_step = int(getattr(ctx.module, "_backward_t_accum_step", 1))
                corr_step = max(1, int(round(abs(float(bwd_weight)) * base_step)))
                if bwd_weight < 0:
                    corr_step = -corr_step
                _triton_accumulate_corr_direct(
                    T_packed, grad_2d, x_2d, ctx.module.corr_accum,
                    N, K, group_size, corr_step=corr_step,
                )
                ctx.module.step_counter.add_(abs(corr_step))
                ctx.module._streamed_bigint_backward = True
            elif _HAS_TRITON:
                grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, N, K)
                if comp_name is not None:
                    setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
                else:
                    ctx.module._hook_grad_T_sign = grad_sign.detach()
            elif comp_name is not None:
                setattr(ctx.module, f"_hook_grad_2d_{comp_name}", grad_2d.detach())
                setattr(ctx.module, f"_hook_x_2d_{comp_name}", x_2d.detach())
            else:
                ctx.module._hook_grad_2d = grad_2d.detach()
                ctx.module._hook_x_2d = x_2d.detach()
            # Kernel output is float32; cast back to the input's dtype and shape.
            grad_x_reshaped = grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype)
            return grad_x_reshaped, None, None
|
|
|
|
| if _HAS_TRITON: |
|
|
    @triton.jit
    def _triton_ternary_fwd_kernel(
        x_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        CORR_STRENGTH: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Compute one ``(BLOCK_M, BLOCK_N)`` tile of ``out = x @ W.T`` in float32.

        ``x`` is read as row-major ``(M, K)`` and ``out`` is written as
        row-major ``(M, N)``. ``W`` is rebuilt per K-slice and never stored:
        ``W[n, k] = sign * exp2(E[g] + CORR_STRENGTH * corr[g] / max(step * GROUP_SIZE, 1))``,
        where ``g = n * GPR + k // GROUP_SIZE`` and ``sign`` is base-3 digit
        ``(n*K + k) % 5`` of ``packed[(n*K + k) // 5]``, minus one.
        Program axis 0 tiles M; axis 1 tiles N.

        NOTE(review): ``w`` is float32 and Triton's ``tl.dot`` expects both
        operands to share a dtype, so ``x`` presumably must be float32.
        NOTE(review): unlike the TileLang kernels, the adjusted exponent is not
        clamped before ``exp2``.
        """
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for k0 in range(0, K, BLOCK_K):
            k = k0 + offs_k
            x = tl.load(
                x_ptr + offs_m[:, None] * K + k[None, :],
                mask=(offs_m[:, None] < M) & (k[None, :] < K),
                other=0.0,
            )

            # Decode sign trits: select base-3 digit trit_pos of each packed byte.
            lin = offs_n[:, None] * K + k[None, :]
            pack_idx = lin // 5
            trit_pos = lin - pack_idx * 5
            packed = tl.load(
                packed_ptr + pack_idx,
                mask=(offs_n[:, None] < N) & (k[None, :] < K),
                other=0,
            ).to(tl.int32)
            divisor = tl.where(
                trit_pos == 0, 1,
                tl.where(trit_pos == 1, 3,
                tl.where(trit_pos == 2, 9,
                tl.where(trit_pos == 3, 27, 81))),
            )
            trit = (packed // divisor) % 3
            sign = trit.to(tl.int32) - 1

            # Per-group exponent plus the mean correction scaled by CORR_STRENGTH.
            e_idx = offs_n[:, None] * GPR + k[None, :] // GROUP_SIZE
            e_val = tl.load(
                e_ptr + e_idx,
                mask=(offs_n[:, None] < N) & (k[None, :] < K),
                other=0,
            ).to(tl.float32)
            corr_val = tl.load(
                corr_ptr + e_idx,
                mask=(offs_n[:, None] < N) & (k[None, :] < K),
                other=0,
            ).to(tl.float32)
            step_val = tl.load(step_ptr).to(tl.float32)
            denom = tl.maximum(step_val * GROUP_SIZE, 1.0)
            e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
            w = sign.to(tl.float32) * tl.exp2(e_adj)
            # Masked loads decode to sign -1 with exponent 0 (w = -1), so
            # out-of-range lanes must be zeroed explicitly.
            w = tl.where((offs_n[:, None] < N) & (k[None, :] < K), w, 0.0)
            acc += tl.dot(x, tl.trans(w))

        tl.store(
            out_ptr + offs_m[:, None] * N + offs_n[None, :],
            acc,
            mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
        )
|
|
|
|
    @triton.jit
    def _triton_ternary_grad_x_kernel(
        grad_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        CORR_STRENGTH: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Compute one ``(BLOCK_M, BLOCK_K)`` tile of ``out = grad @ W`` in float32.

        ``grad`` is read as row-major ``(M, N)`` and ``out`` is written as
        row-major ``(M, K)``. ``W`` is dequantized per N-slice with the same
        sign/exponent/correction formula as ``_triton_ternary_fwd_kernel``.
        Program axis 0 tiles M; axis 1 tiles K.

        NOTE(review): ``w`` is float32, so ``grad`` presumably must be float32
        for ``tl.dot``'s matching-dtype requirement.
        """
        pid_m = tl.program_id(0)
        pid_k = tl.program_id(1)

        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
        offs_n = tl.arange(0, BLOCK_N)
        acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)

        for n0 in range(0, N, BLOCK_N):
            n = n0 + offs_n
            grad = tl.load(
                grad_ptr + offs_m[:, None] * N + n[None, :],
                mask=(offs_m[:, None] < M) & (n[None, :] < N),
                other=0.0,
            )

            # Decode sign trits for rows n, columns offs_k.
            lin = n[:, None] * K + offs_k[None, :]
            pack_idx = lin // 5
            trit_pos = lin - pack_idx * 5
            packed = tl.load(
                packed_ptr + pack_idx,
                mask=(n[:, None] < N) & (offs_k[None, :] < K),
                other=0,
            ).to(tl.int32)
            divisor = tl.where(
                trit_pos == 0, 1,
                tl.where(trit_pos == 1, 3,
                tl.where(trit_pos == 2, 9,
                tl.where(trit_pos == 3, 27, 81))),
            )
            trit = (packed // divisor) % 3
            sign = trit.to(tl.int32) - 1

            e_idx = n[:, None] * GPR + offs_k[None, :] // GROUP_SIZE
            e_val = tl.load(
                e_ptr + e_idx,
                mask=(n[:, None] < N) & (offs_k[None, :] < K),
                other=0,
            ).to(tl.float32)
            corr_val = tl.load(
                corr_ptr + e_idx,
                mask=(n[:, None] < N) & (offs_k[None, :] < K),
                other=0,
            ).to(tl.float32)
            step_val = tl.load(step_ptr).to(tl.float32)
            denom = tl.maximum(step_val * GROUP_SIZE, 1.0)
            e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
            w = sign.to(tl.float32) * tl.exp2(e_adj)
            # Zero lanes outside (N, K); masked loads would otherwise decode to -1.
            w = tl.where((n[:, None] < N) & (offs_k[None, :] < K), w, 0.0)
            acc += tl.dot(grad, w)

        tl.store(
            out_ptr + offs_m[:, None] * K + offs_k[None, :],
            acc,
            mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
        )
|
|
|
|
    @triton.jit
    def _triton_ternary_grad_sign_kernel(
        grad_ptr, x_ptr, sign_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Write one ``(BLOCK_N, BLOCK_K)`` tile of ``sign(grad.T @ x)`` as int8.

        ``grad`` is row-major ``(M, N)``, ``x`` is row-major ``(M, K)`` and the
        output is row-major ``(N, K)`` with values in ``{-1, 0, 1}``. The
        reduction over M uses ``input_precision="ieee"``.
        Program axis 0 tiles N; axis 1 tiles K.
        """
        pid_n = tl.program_id(0)
        pid_k = tl.program_id(1)

        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
        offs_m = tl.arange(0, BLOCK_M)
        acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)

        for m0 in range(0, M, BLOCK_M):
            m = m0 + offs_m
            grad = tl.load(
                grad_ptr + m[:, None] * N + offs_n[None, :],
                mask=(m[:, None] < M) & (offs_n[None, :] < N),
                other=0.0,
            )
            x = tl.load(
                x_ptr + m[:, None] * K + offs_k[None, :],
                mask=(m[:, None] < M) & (offs_k[None, :] < K),
                other=0.0,
            )
            acc += tl.dot(tl.trans(grad), x, input_precision="ieee")

        # Map the accumulated weight gradient to its sign.
        sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0))
        tl.store(
            sign_ptr + offs_n[:, None] * K + offs_k[None, :],
            sign.to(tl.int8),
            mask=(offs_n[:, None] < N) & (offs_k[None, :] < K),
        )
|
|
|
|
    @triton.jit
    def _triton_update_e_kernel(
        packed_ptr, grad_sign_ptr, e_ptr, e_accum_ptr,
        N: tl.constexpr, K: tl.constexpr,
        GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
        E_ACCUM_THRESHOLD: tl.constexpr,
        BLOCK_N: tl.constexpr, BLOCK_G: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Nudge per-group exponents using a precomputed int8 gradient-sign map.

        For each ``(row n, group g)`` in the ``(BLOCK_N, BLOCK_G)`` tile:

        1. ``score = sum(grad_sign * ternary)`` over the group's columns.
        2. ``delta = -sign(score)``.
        3. ``accum = clamp(e_accum + delta, -128, 127)``.
        4. When ``accum >= E_ACCUM_THRESHOLD`` (or ``<= -E_ACCUM_THRESHOLD``),
           ``E`` moves by +1 (or -1), clamped to int8, and the threshold is
           subtracted from the accumulator.

        ``grad_sign`` is row-major ``(N, K)``; ``E``/``e_accum`` are ``(N, GPR)``.
        Only the first ``GROUP_SIZE`` of the ``BLOCK_K`` lanes per group are used.
        Program axis 0 tiles N; axis 1 tiles groups.
        """
        pid_n = tl.program_id(0)
        pid_g = tl.program_id(1)

        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_g = pid_g * BLOCK_G + tl.arange(0, BLOCK_G)
        offs_r = tl.arange(0, BLOCK_K)
        # Column index of lane r inside group g: shape (BLOCK_G, BLOCK_K).
        k = offs_g[:, None] * GROUP_SIZE + offs_r[None, :]
        valid_group = offs_g < GPR

        # 3-D (BLOCK_N, BLOCK_G, BLOCK_K) decode of the sign trits.
        lin = offs_n[:, None, None] * K + k[None, :, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        packed = tl.load(
            packed_ptr + pack_idx,
            mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K),
            other=0,
        ).to(tl.int32)
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
            tl.where(trit_pos == 2, 9,
            tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        ternary = trit.to(tl.int32) - 1

        # Masked grad_sign lanes load 0, so they add nothing to the score.
        grad_sign = tl.load(
            grad_sign_ptr + offs_n[:, None, None] * K + k[None, :, :],
            mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K),
            other=0,
        ).to(tl.int32)
        contrib = grad_sign * ternary
        score = tl.sum(contrib, axis=2)
        delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))

        e_idx = offs_n[:, None] * GPR + offs_g[None, :]
        old_accum = tl.load(
            e_accum_ptr + e_idx,
            mask=(offs_n[:, None] < N) & valid_group[None, :],
            other=0,
        ).to(tl.int32)
        new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta))
        step_up = new_accum >= E_ACCUM_THRESHOLD
        step_down = new_accum <= -E_ACCUM_THRESHOLD
        e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0))
        # Keep the remainder after paying out a step.
        stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD

        old_e = tl.load(
            e_ptr + e_idx,
            mask=(offs_n[:, None] < N) & valid_group[None, :],
            other=0,
        ).to(tl.int32)
        new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step))
        tl.store(
            e_ptr + e_idx,
            new_e.to(tl.int8),
            mask=(offs_n[:, None] < N) & valid_group[None, :],
        )
        tl.store(
            e_accum_ptr + e_idx,
            stored_accum.to(tl.int8),
            mask=(offs_n[:, None] < N) & valid_group[None, :],
        )
|
|
|
|
    @triton.jit
    def _triton_ternary_step_kernel(
        packed_ptr, grad_sign_ptr, accum_ptr, per_group_threshold_ptr,
        TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
        T_ACCUM_STEP: tl.constexpr,
        K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        HAS_PER_GROUP_THRESHOLD: tl.constexpr,
        BLOCK_T: tl.constexpr,
    ):
        """Accumulate a precomputed gradient sign into each trit and apply flips.

        One program handles one packed byte (``pack_idx = program_id(0)``),
        i.e. the trits with linear index ``pack_idx*5 + t`` for ``t < 5`` and
        index ``< TOTAL``. For each trit:

        * ``accum = clamp(accum - grad_sign * T_ACCUM_STEP, -128, 127)``.
        * If ``accum > threshold`` the sign becomes +1; if
          ``accum < -threshold`` it becomes -1. Either way ``accum`` resets to 0.
        * Otherwise the sign is kept.

        ``threshold`` is ``ACCUM_THRESHOLD``, or, with
        ``HAS_PER_GROUP_THRESHOLD``, the per-group value at
        ``n * GPR + k // GROUP_SIZE``. The byte is then re-encoded.

        ``BLOCK_T`` must be >= 5; lanes ``>= 5`` are masked off.
        NOTE(review): the byte load/store is unmasked, so the grid presumably
        equals the number of packed bytes. Trits at ``lin >= TOTAL`` in the
        last byte are re-encoded as digit 0.
        """
        pack_idx = tl.program_id(0)
        offs_t = tl.arange(0, BLOCK_T)
        valid_trit = offs_t < 5
        lin = pack_idx * 5 + offs_t
        valid = valid_trit & (lin < TOTAL)

        old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
        # Lane t decodes base-3 digit t of the byte.
        divisor = tl.where(
            offs_t == 0, 1,
            tl.where(offs_t == 1, 3,
            tl.where(offs_t == 2, 9,
            tl.where(offs_t == 3, 27, 81))),
        )
        old_code = (old_packed // divisor) % 3
        old_sign = old_code.to(tl.int32) - 1

        grad_sign = tl.load(grad_sign_ptr + lin, mask=valid, other=0).to(tl.int32)
        old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
        new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))

        if HAS_PER_GROUP_THRESHOLD:
            n = lin // K
            k = lin - n * K
            g_idx = n * GPR + k // GROUP_SIZE
            threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
        else:
            threshold = ACCUM_THRESHOLD

        flip_up = new_accum > threshold
        flip_down = new_accum < -threshold
        did_flip = valid & (flip_up | flip_down)
        new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
        stored_accum = tl.where(did_flip, 0, new_accum)
        tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)

        # Re-encode: sum of (sign + 1) * 3**t over the valid lanes.
        new_code = tl.where(valid, new_sign + 1, 0)
        packed_val = tl.sum(new_code * divisor, axis=0)
        tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
|
|
|
|
    @triton.jit
    def _triton_update_e_direct_kernel(
        packed_ptr, grad_ptr, x_ptr, e_ptr, e_accum_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
        E_ACCUM_THRESHOLD: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Update per-group exponents from ``grad``/``x`` without a sign buffer.

        Program axis 0 tiles rows; axis 1 is the group index ``pid_g`` itself.
        For each row ``n`` and group ``pid_g``, the kernel:

        1. Computes ``grad.T @ x`` for the group's columns (IEEE ``tl.dot``
           over M) and takes its sign.
        2. Forms ``score = sum(grad_sign * ternary)`` and ``delta = -sign(score)``.
        3. Adds ``delta`` to the clamped int8 accumulator; reaching
           ``+-E_ACCUM_THRESHOLD`` steps ``E`` by +-1 and subtracts the threshold.

        ``grad`` is row-major ``(M, N)``; ``x`` is row-major ``(M, K)``.
        NOTE(review): ``BLOCK_K`` is a ``tl.dot`` dimension; confirm the launcher
        keeps it large enough for Triton's minimum dot tile size.
        """
        pid_n = tl.program_id(0)
        pid_g = tl.program_id(1)

        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_r = tl.arange(0, BLOCK_K)
        k = pid_g * GROUP_SIZE + offs_r
        offs_m = tl.arange(0, BLOCK_M)
        acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)

        # Weight-gradient slice dW[n, group columns] = sum_m grad[m, n] * x[m, k].
        for m0 in range(0, M, BLOCK_M):
            m = m0 + offs_m
            grad = tl.load(
                grad_ptr + m[:, None] * N + offs_n[None, :],
                mask=(m[:, None] < M) & (offs_n[None, :] < N),
                other=0.0,
            )
            x = tl.load(
                x_ptr + m[:, None] * K + k[None, :],
                mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
                other=0.0,
            )
            acc += tl.dot(tl.trans(grad), x, input_precision="ieee")

        grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
        lin = offs_n[:, None] * K + k[None, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        packed = tl.load(
            packed_ptr + pack_idx,
            mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
            other=0,
        ).to(tl.int32)
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
            tl.where(trit_pos == 2, 9,
            tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        ternary = trit.to(tl.int32) - 1
        contrib = tl.where(
            (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
            grad_sign * ternary,
            0,
        )
        score = tl.sum(contrib, axis=1)
        delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))

        e_idx = offs_n * GPR + pid_g
        old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
        new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta))
        step_up = new_accum >= E_ACCUM_THRESHOLD
        step_down = new_accum <= -E_ACCUM_THRESHOLD
        e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0))
        stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD

        old_e = tl.load(e_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
        new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step))
        tl.store(e_ptr + e_idx, new_e.to(tl.int8), mask=offs_n < N)
        tl.store(e_accum_ptr + e_idx, stored_accum.to(tl.int8), mask=offs_n < N)
|
|
|
|
    @triton.jit
    def _triton_ternary_step_direct_kernel(
        packed_ptr, grad_ptr, x_ptr, accum_ptr, per_group_threshold_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
        T_ACCUM_STEP: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        HAS_PER_GROUP_THRESHOLD: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
    ):
        """Per-byte trit update that derives the gradient sign from ``grad``/``x``.

        One program per packed byte. For each of its (up to five) trits at
        ``(n, k)``, the kernel:

        1. Reduces ``sum_m grad[m, n] * x[m, k]`` elementwise over M and takes
           its sign.
        2. Applies ``accum = clamp(accum - grad_sign * T_ACCUM_STEP, -128, 127)``.
        3. Sets the sign to +1 / -1 when ``accum`` exceeds ``+threshold`` /
           ``-threshold``, resetting ``accum`` to 0 in that case.
        4. Re-encodes the byte.

        ``threshold`` is ``ACCUM_THRESHOLD`` or the per-group value when
        ``HAS_PER_GROUP_THRESHOLD``. ``BLOCK_T`` must be >= 5.
        NOTE(review): the byte load/store is unmasked, so the grid presumably
        equals the number of packed bytes.
        """
        pack_idx = tl.program_id(0)
        offs_t = tl.arange(0, BLOCK_T)
        lin = pack_idx * 5 + offs_t
        valid_trit = offs_t < 5
        valid = valid_trit & (lin < TOTAL)
        n = lin // K
        k = lin - n * K

        offs_m = tl.arange(0, BLOCK_M)
        acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
        for m0 in range(0, M, BLOCK_M):
            m = m0 + offs_m
            grad = tl.load(
                grad_ptr + m[:, None] * N + n[None, :],
                mask=(m[:, None] < M) & valid[None, :],
                other=0.0,
            )
            x = tl.load(
                x_ptr + m[:, None] * K + k[None, :],
                mask=(m[:, None] < M) & valid[None, :],
                other=0.0,
            )
            acc += tl.sum(grad * x, axis=0)

        grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)

        old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
        # Lane t decodes base-3 digit t of the byte.
        divisor = tl.where(
            offs_t == 0, 1,
            tl.where(offs_t == 1, 3,
            tl.where(offs_t == 2, 9,
            tl.where(offs_t == 3, 27, 81))),
        )
        old_code = (old_packed // divisor) % 3
        old_sign = old_code.to(tl.int32) - 1

        old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
        new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))

        if HAS_PER_GROUP_THRESHOLD:
            g_idx = n * GPR + k // GROUP_SIZE
            threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
        else:
            threshold = ACCUM_THRESHOLD

        flip_up = new_accum > threshold
        flip_down = new_accum < -threshold
        did_flip = valid & (flip_up | flip_down)
        new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
        stored_accum = tl.where(did_flip, 0, new_accum)
        tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)

        new_code = tl.where(valid, new_sign + 1, 0)
        packed_val = tl.sum(new_code * divisor, axis=0)
        tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
|
|
|
|
    @triton.jit
    def _triton_accumulate_t_direct_kernel(
        grad_ptr, x_ptr, accum_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        TOTAL: tl.constexpr, T_ACCUM_STEP: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
    ):
        """Accumulate trit flip pressure without applying any flips.

        One program per group of five linear indices (``pack_idx * 5 + t``). For
        each valid index ``(n, k)`` it takes ``sign(sum_m grad[m, n] * x[m, k])``
        and stores ``clamp(accum - sign * T_ACCUM_STEP, -128, 127)`` as int8.
        Packed signs are not touched here.
        """
        pack_idx = tl.program_id(0)
        offs_t = tl.arange(0, BLOCK_T)
        lin = pack_idx * 5 + offs_t
        valid_trit = offs_t < 5
        valid = valid_trit & (lin < TOTAL)
        n = lin // K
        k = lin - n * K

        offs_m = tl.arange(0, BLOCK_M)
        acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
        for m0 in range(0, M, BLOCK_M):
            m = m0 + offs_m
            grad = tl.load(
                grad_ptr + m[:, None] * N + n[None, :],
                mask=(m[:, None] < M) & valid[None, :],
                other=0.0,
            )
            x = tl.load(
                x_ptr + m[:, None] * K + k[None, :],
                mask=(m[:, None] < M) & valid[None, :],
                other=0.0,
            )
            acc += tl.sum(grad * x, axis=0)

        grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
        old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
        new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP))
        tl.store(accum_ptr + lin, new_accum.to(tl.int8), mask=valid)
|
|
|
|
    @triton.jit
    def _triton_accumulate_e_direct_kernel(
        packed_ptr, grad_ptr, x_ptr, e_accum_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
        E_ACCUM_STEP: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Accumulate exponent pressure per group without stepping ``E``.

        Program axis 0 tiles rows; axis 1 is the group index ``pid_g``. For each
        row and group it computes ``sign(grad.T @ x)`` over the group's columns
        (IEEE ``tl.dot``), ``score = sum(grad_sign * ternary)`` and
        ``delta = -sign(score)``, then stores
        ``clamp(e_accum + delta * E_ACCUM_STEP, -128, 127)`` as int8.
        """
        pid_n = tl.program_id(0)
        pid_g = tl.program_id(1)

        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_r = tl.arange(0, BLOCK_K)
        k = pid_g * GROUP_SIZE + offs_r
        offs_m = tl.arange(0, BLOCK_M)
        acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)

        for m0 in range(0, M, BLOCK_M):
            m = m0 + offs_m
            grad = tl.load(
                grad_ptr + m[:, None] * N + offs_n[None, :],
                mask=(m[:, None] < M) & (offs_n[None, :] < N),
                other=0.0,
            )
            x = tl.load(
                x_ptr + m[:, None] * K + k[None, :],
                mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
                other=0.0,
            )
            acc += tl.dot(tl.trans(grad), x, input_precision="ieee")

        grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
        lin = offs_n[:, None] * K + k[None, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        packed = tl.load(
            packed_ptr + pack_idx,
            mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
            other=0,
        ).to(tl.int32)
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
            tl.where(trit_pos == 2, 9,
            tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        ternary = trit.to(tl.int32) - 1
        contrib = tl.where(
            (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
            grad_sign * ternary,
            0,
        )
        score = tl.sum(contrib, axis=1)
        delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))

        e_idx = offs_n * GPR + pid_g
        old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32)
        new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta * E_ACCUM_STEP))
        tl.store(e_accum_ptr + e_idx, new_accum.to(tl.int8), mask=offs_n < N)
|
|
|
|
    @triton.jit
    def _triton_accumulate_corr_direct_kernel(
        packed_ptr, grad_ptr, x_ptr, corr_ptr,
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
        GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
        CORR_STEP: tl.constexpr,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Fold the per-group agreement score into the int64 exponent correction.

        Program axis 0 tiles rows; axis 1 is the group index ``pid_g``. For each
        row and group it computes ``sign(grad.T @ x)`` over the group's columns
        (IEEE ``tl.dot``) and ``score = sum(grad_sign * ternary)`` (range
        ``[-GROUP_SIZE, GROUP_SIZE]``), then stores
        ``corr -= score * CORR_STEP`` in int64 with no clamping. The forward
        kernels read this as ``corr / max(step * GROUP_SIZE, 1)``.
        """
        pid_n = tl.program_id(0)
        pid_g = tl.program_id(1)

        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_r = tl.arange(0, BLOCK_K)
        k = pid_g * GROUP_SIZE + offs_r
        offs_m = tl.arange(0, BLOCK_M)
        acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)

        for m0 in range(0, M, BLOCK_M):
            m = m0 + offs_m
            grad = tl.load(
                grad_ptr + m[:, None] * N + offs_n[None, :],
                mask=(m[:, None] < M) & (offs_n[None, :] < N),
                other=0.0,
            )
            x = tl.load(
                x_ptr + m[:, None] * K + k[None, :],
                mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
                other=0.0,
            )
            acc += tl.dot(tl.trans(grad), x, input_precision="ieee")

        grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
        lin = offs_n[:, None] * K + k[None, :]
        pack_idx = lin // 5
        trit_pos = lin - pack_idx * 5
        packed = tl.load(
            packed_ptr + pack_idx,
            mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
            other=0,
        ).to(tl.int32)
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
            tl.where(trit_pos == 2, 9,
            tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        ternary = trit.to(tl.int32) - 1
        contrib = tl.where(
            (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K),
            grad_sign * ternary,
            0,
        )
        score = tl.sum(contrib, axis=1)

        corr_idx = offs_n * GPR + pid_g
        old_corr = tl.load(corr_ptr + corr_idx, mask=offs_n < N, other=0).to(tl.int64)
        new_corr = old_corr - score.to(tl.int64) * CORR_STEP
        tl.store(corr_ptr + corr_idx, new_corr, mask=offs_n < N)
|
|
|
|
    @triton.jit
    def _triton_apply_accumulated_flips_kernel(
        packed_ptr, accum_ptr, per_group_threshold_ptr,
        TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
        K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        HAS_PER_GROUP_THRESHOLD: tl.constexpr,
        BLOCK_T: tl.constexpr,
    ):
        """Apply flips from already-accumulated trit pressure (no new gradient).

        One program per packed byte. A trit whose stored ``accum`` exceeds
        ``+threshold`` becomes +1, and one below ``-threshold`` becomes -1; in
        both cases ``accum`` is reset to 0. Other trits keep their sign and
        accumulator. ``threshold`` is ``ACCUM_THRESHOLD`` or the per-group value
        when ``HAS_PER_GROUP_THRESHOLD``. ``BLOCK_T`` must be >= 5.
        NOTE(review): the byte load/store is unmasked, so the grid presumably
        equals the number of packed bytes.
        """
        pack_idx = tl.program_id(0)
        offs_t = tl.arange(0, BLOCK_T)
        valid_trit = offs_t < 5
        lin = pack_idx * 5 + offs_t
        valid = valid_trit & (lin < TOTAL)

        old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32)
        # Lane t decodes base-3 digit t of the byte.
        divisor = tl.where(
            offs_t == 0, 1,
            tl.where(offs_t == 1, 3,
            tl.where(offs_t == 2, 9,
            tl.where(offs_t == 3, 27, 81))),
        )
        old_code = (old_packed // divisor) % 3
        old_sign = old_code.to(tl.int32) - 1

        old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32)
        if HAS_PER_GROUP_THRESHOLD:
            n = lin // K
            k = lin - n * K
            g_idx = n * GPR + k // GROUP_SIZE
            threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32)
        else:
            threshold = ACCUM_THRESHOLD

        flip_up = old_accum > threshold
        flip_down = old_accum < -threshold
        did_flip = valid & (flip_up | flip_down)
        new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign))
        stored_accum = tl.where(did_flip, 0, old_accum)
        tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid)

        new_code = tl.where(valid, new_sign + 1, 0)
        packed_val = tl.sum(new_code * divisor, axis=0)
        tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8))
|
|
|
|
def _triton_ternary_forward(x_2d, packed, e, corr_accum, step_counter, n_out, k_in, group_size):
    """Launch the Triton ternary matmul forward kernel.

    Returns a freshly allocated float32 tensor of shape ``(x_2d.shape[0], n_out)``.
    ``corr_accum``, ``step_counter`` and the ``ARB_BIGINT_CORR_STRENGTH`` value
    are forwarded to ``_triton_ternary_fwd_kernel``, presumably for the scale
    correction (the kernel body is not visible here).
    """
    rows = x_2d.shape[0]
    tile_m, tile_n, tile_k = 16, 16, 32
    result = torch.empty((rows, n_out), device=x_2d.device, dtype=torch.float32)
    launch_grid = (triton.cdiv(rows, tile_m), triton.cdiv(n_out, tile_n))
    groups_per_row = ceil(k_in / group_size)
    _triton_ternary_fwd_kernel[launch_grid](
        x_2d, packed, e, corr_accum, step_counter, result,
        rows, n_out, k_in, groups_per_row, group_size,
        _bigint_corr_strength(),
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )
    return result
|
|
|
|
def _triton_ternary_grad_x(grad_2d, packed, e, corr_accum, step_counter, m_rows, n_out, k_in, group_size):
    """Launch the Triton kernel computing the input gradient of the ternary linear.

    Returns a new float32 tensor of shape ``(m_rows, k_in)``; the grid tiles
    rows by 16 and input columns by 32.
    """
    tile_m, tile_n, tile_k = 16, 16, 32
    grad_in = torch.empty((m_rows, k_in), device=grad_2d.device, dtype=torch.float32)
    launch_grid = (triton.cdiv(m_rows, tile_m), triton.cdiv(k_in, tile_k))
    groups_per_row = ceil(k_in / group_size)
    _triton_ternary_grad_x_kernel[launch_grid](
        grad_2d, packed, e, corr_accum, step_counter, grad_in,
        m_rows, n_out, k_in, groups_per_row, group_size,
        _bigint_corr_strength(),
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )
    return grad_in
|
|
|
|
def _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in):
    """Launch the Triton kernel producing an int8 ``(n_out, k_in)`` weight-gradient sign map."""
    tile_m, tile_n, tile_k = 32, 16, 32
    signs = torch.empty((n_out, k_in), device=grad_2d.device, dtype=torch.int8)
    _triton_ternary_grad_sign_kernel[(triton.cdiv(n_out, tile_n), triton.cdiv(k_in, tile_k))](
        grad_2d, x_2d, signs,
        x_2d.shape[0], n_out, k_in,
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )
    return signs
|
|
|
|
def _triton_update_e(packed, grad_sign, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
    """Launch ``_triton_update_e_kernel`` over (rows / 8) x (groups / 4) tiles.

    ``BLOCK_K`` is the smallest power of two >= ``group_size``. Returns ``None``;
    the kernel writes its results through the passed tensors.
    """
    groups_per_row = ceil(k_in / group_size)
    tile_k = 1 << (group_size - 1).bit_length()
    tile_n, tile_g = 8, 4
    _triton_update_e_kernel[(triton.cdiv(n_out, tile_n), triton.cdiv(groups_per_row, tile_g))](
        packed, grad_sign, e, e_accum,
        n_out, k_in, group_size, groups_per_row, int(e_accum_threshold),
        BLOCK_N=tile_n, BLOCK_G=tile_g, BLOCK_K=tile_k,
    )
|
|
|
|
def _triton_update_e_direct(packed, grad_2d, x_2d, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
    """Launch ``_triton_update_e_direct_kernel`` with one grid column per scale group.

    Works from the raw ``grad_2d`` / ``x_2d`` activations rather than a
    precomputed sign map. ``BLOCK_K`` is the next power of two >= ``group_size``.
    """
    tile_k = 1 << (group_size - 1).bit_length()
    groups_per_row = ceil(k_in / group_size)
    tile_m, tile_n = 32, 8
    launch_grid = (triton.cdiv(n_out, tile_n), groups_per_row)
    _triton_update_e_direct_kernel[launch_grid](
        packed, grad_2d, x_2d, e, e_accum,
        x_2d.shape[0], n_out, k_in, group_size, groups_per_row, int(e_accum_threshold),
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )
|
|
|
|
def _triton_ternary_step(packed, grad_sign, accum, total, accum_threshold, t_accum_step=1,
                         per_group_threshold=None, n_out=0, k_in=0, group_size=0):
    """Launch ``_triton_ternary_step_kernel``: one program per packed byte (5 trits).

    When ``per_group_threshold`` is ``None``, a 1-element int8 placeholder is
    passed in its slot and ``k_in`` / groups-per-row / ``group_size`` are sent
    as 0. Presumably the kernel only reads them when the boolean flag is set.
    ``n_out`` is accepted but not used here.
    """
    launch_grid = (triton.cdiv(total, 5),)
    use_group_thresholds = per_group_threshold is not None
    placeholder = torch.empty(1, device=accum.device, dtype=torch.int8)
    if use_group_thresholds:
        thresholds = per_group_threshold
        groups_per_row = (k_in + group_size - 1) // group_size
        kern_k, kern_group = k_in, group_size
    else:
        thresholds = placeholder
        groups_per_row = 0
        kern_k, kern_group = 0, 0
    _triton_ternary_step_kernel[launch_grid](
        packed, grad_sign, accum,
        thresholds,
        total, accum_threshold, int(t_accum_step),
        kern_k, groups_per_row, kern_group,
        use_group_thresholds,
        BLOCK_T=8,
    )
|
|
|
|
def _triton_ternary_step_direct(packed, grad_2d, x_2d, accum, n_out, k_in, total, accum_threshold, t_accum_step=1,
                                per_group_threshold=None, group_size=0):
    """Launch ``_triton_ternary_step_direct_kernel``: one program per packed byte (5 trits).

    Reads the raw ``grad_2d`` / ``x_2d`` activations. Without
    ``per_group_threshold``, an int8 placeholder fills its slot and the group
    count and ``group_size`` are sent as 0.
    """
    launch_grid = (triton.cdiv(total, 5),)
    use_group_thresholds = per_group_threshold is not None
    placeholder = torch.empty(1, device=accum.device, dtype=torch.int8)
    if use_group_thresholds:
        thresholds = per_group_threshold
        groups_per_row = (k_in + group_size - 1) // group_size
        kern_group = group_size
    else:
        thresholds = placeholder
        groups_per_row = 0
        kern_group = 0
    _triton_ternary_step_direct_kernel[launch_grid](
        packed, grad_2d, x_2d, accum,
        thresholds,
        x_2d.shape[0], n_out, k_in,
        total, accum_threshold, int(t_accum_step),
        groups_per_row, kern_group,
        use_group_thresholds,
        BLOCK_M=32, BLOCK_T=8,
    )
|
|
|
|
def _triton_accumulate_direct(packed, grad_2d, x_2d, t_accum, e_accum,
                              n_out, k_in, group_size,
                              t_accum_step=1, e_accum_step=1,
                              update_scales=True):
    """Accumulate trit votes and, optionally, scale votes from raw activations.

    Always launches ``_triton_accumulate_t_direct_kernel`` (one program per
    packed byte). The per-group ``_triton_accumulate_e_direct_kernel`` runs only
    when ``update_scales`` is true and ``e_accum`` is provided.
    """
    rows = grad_2d.shape[0]
    tile_m = 32
    n_elems = n_out * k_in
    _triton_accumulate_t_direct_kernel[(triton.cdiv(n_elems, 5),)](
        grad_2d, x_2d, t_accum,
        rows, n_out, k_in, n_elems, int(t_accum_step),
        BLOCK_M=tile_m, BLOCK_T=8,
    )
    if not update_scales or e_accum is None:
        return
    tile_n = 8
    tile_k = 1 << (group_size - 1).bit_length()
    groups_per_row = ceil(k_in / group_size)
    _triton_accumulate_e_direct_kernel[(triton.cdiv(n_out, tile_n), groups_per_row)](
        packed, grad_2d, x_2d, e_accum,
        rows, n_out, k_in, group_size, groups_per_row, int(e_accum_step),
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )
|
|
|
|
def _triton_accumulate_corr_direct(packed, grad_2d, x_2d, corr_accum,
                                   n_out, k_in, group_size, corr_step=1):
    """Launch ``_triton_accumulate_corr_direct_kernel`` with one grid column per scale group.

    ``corr_accum`` is passed through for the kernel to update.
    ``corr_step`` is forwarded as a Python int.
    """
    groups_per_row = ceil(k_in / group_size)
    tile_k = 1 << (group_size - 1).bit_length()
    tile_m, tile_n = 32, 8
    _triton_accumulate_corr_direct_kernel[(triton.cdiv(n_out, tile_n), groups_per_row)](
        packed, grad_2d, x_2d, corr_accum,
        grad_2d.shape[0], n_out, k_in, group_size, groups_per_row, int(corr_step),
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )
|
|
|
|
def _triton_apply_accumulated_flips(packed, accum, total, accum_threshold,
                                    per_group_threshold=None,
                                    k_in=0, group_size=0):
    """Flip packed trits whose accumulator crossed the threshold (in place).

    Launches ``_triton_apply_accumulated_flips_kernel`` with one program per
    packed byte. Without ``per_group_threshold``, an int8 placeholder fills
    that slot and K / groups-per-row / group size are sent as 0. The kernel
    skips the per-group path via its constexpr flag.
    """
    launch_grid = (triton.cdiv(total, 5),)
    use_group_thresholds = per_group_threshold is not None
    placeholder = torch.empty(1, device=accum.device, dtype=torch.int8)
    if use_group_thresholds:
        thresholds, kern_k, kern_group = per_group_threshold, k_in, group_size
        groups_per_row = (k_in + group_size - 1) // group_size
    else:
        thresholds, kern_k, kern_group, groups_per_row = placeholder, 0, 0, 0
    _triton_apply_accumulated_flips_kernel[launch_grid](
        packed, accum,
        thresholds,
        total, accum_threshold,
        kern_k, groups_per_row, kern_group,
        use_group_thresholds,
        BLOCK_T=8,
    )
|
|
|
|
# Gather embedding rows of the ternary table: out[b, :] = sign(T[idx[b], :]) * 2**E.
#
# Each program handles BLOCK_B consecutive indices. The whole row is covered
# by one BLOCK_D tile, so BLOCK_D must be >= DIM (the launcher passes
# next_power_of_2(dim)). Trits are packed 5 per byte in flat row-major order,
# and E holds GPR int8 exponents per vocabulary row.
# NOTE(review): VOCAB is not used, so indices are not bounds-checked here.
# NOTE(review): `lin` is computed in int32 from idx*DIM; for vocab*dim above
# 2**31 this would overflow. Confirm the table sizes in use stay below that.
# NOTE(review): NUM_IDX is tl.constexpr, so each distinct batch size compiles
# a new specialization.
@triton.jit
def _triton_ternary_embed_fwd_kernel(
    idx_ptr, packed_ptr, e_ptr, out_ptr,
    NUM_IDX: tl.constexpr, DIM: tl.constexpr,
    VOCAB: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)
    offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    offs_d = tl.arange(0, BLOCK_D)

    # Token ids for this block (out-of-range lanes read id 0 and are masked later).
    idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32)

    # Flat weight index -> (packed byte, trit position within the byte).
    lin = idx[:, None] * DIM + offs_d[None, :]
    pack_idx = lin // 5
    trit_pos = lin - pack_idx * 5
    packed = tl.load(packed_ptr + pack_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.int32)
    # Base-3 place value for each trit position.
    divisor = tl.where(
        trit_pos == 0, 1,
        tl.where(trit_pos == 1, 3,
        tl.where(trit_pos == 2, 9,
        tl.where(trit_pos == 3, 27, 81))),
    )
    trit = (packed // divisor) % 3
    sign = trit.to(tl.int32) - 1

    # One exponent per GROUP_SIZE columns of the selected row.
    e_idx = idx[:, None] * GPR + offs_d[None, :] // GROUP_SIZE
    e_val = tl.load(e_ptr + e_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.float32)
    w = sign.to(tl.float32) * tl.exp2(e_val)
    w = tl.where((offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), w, 0.0)

    tl.store(
        out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
        w,
        mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM),
    )
|
|
|
|
# Scatter-add gradient rows into a dense (vocab, DIM) buffer:
# accum[idx[b], :] += grad[b, :]. Repeated ids are summed via atomic_add.
# BLOCK_D must be >= DIM (one tile covers the whole row).
# NOTE(review): floating-point atomic accumulation order is not fixed, so the
# sums (and therefore signs of near-zero totals) may differ between runs.
@triton.jit
def _triton_ternary_embed_bwd_accum_kernel(
    idx_ptr, grad_ptr, accum_ptr,
    NUM_IDX: tl.constexpr, DIM: tl.constexpr,
    BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)
    offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    offs_d = tl.arange(0, BLOCK_D)
    valid = (offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM)
    idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32)
    g = tl.load(grad_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0)
    # Destination row is selected by the token id.
    dst = idx[:, None] * DIM + offs_d[None, :]
    tl.atomic_add(accum_ptr + dst, g, mask=valid)
|
|
|
|
# Convert the dense accumulated gradient to an int8 sign map in {-1, 0, +1}.
# Each program handles BLOCK_V vocabulary rows; BLOCK_D must be >= DIM.
@triton.jit
def _triton_ternary_embed_bwd_sign_kernel(
    accum_ptr, sign_ptr,
    VOCAB: tl.constexpr, DIM: tl.constexpr,
    BLOCK_V: tl.constexpr, BLOCK_D: tl.constexpr,
):
    pid_v = tl.program_id(0)
    offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V)
    offs_d = tl.arange(0, BLOCK_D)
    valid = (offs_v[:, None] < VOCAB) & (offs_d[None, :] < DIM)
    acc = tl.load(accum_ptr + offs_v[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0)
    # Exact zero maps to 0 (no vote).
    sign_val = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int8)
    tl.store(sign_ptr + offs_v[:, None] * DIM + offs_d[None, :], sign_val, mask=valid)
|
|
|
|
def _triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim):
    """Return an int8 ``(vocab, dim)`` map of sign(summed gradient) per vocabulary row.

    Gradient rows sharing a token id are summed into a float32 buffer by
    ``_triton_ternary_embed_bwd_accum_kernel``. The sum is then reduced to
    {-1, 0, +1} by ``_triton_ternary_embed_bwd_sign_kernel``. Rows never
    indexed stay at 0.
    """
    idx_flat = indices.reshape(-1).contiguous().to(torch.int32)
    grad_rows = grad_output.reshape(-1, dim).contiguous()
    n_idx = idx_flat.shape[0]
    dense = torch.zeros(vocab, dim, device=grad_output.device, dtype=torch.float32)
    rows_per_prog = 64
    _triton_ternary_embed_bwd_accum_kernel[(triton.cdiv(n_idx, rows_per_prog),)](
        idx_flat, grad_rows, dense,
        n_idx, dim,
        BLOCK_B=rows_per_prog, BLOCK_D=triton.next_power_of_2(dim),
    )
    signs = torch.empty(vocab, dim, device=grad_output.device, dtype=torch.int8)
    vocab_per_prog = 32
    _triton_ternary_embed_bwd_sign_kernel[(triton.cdiv(vocab, vocab_per_prog),)](
        dense, signs,
        vocab, dim,
        BLOCK_V=vocab_per_prog, BLOCK_D=triton.next_power_of_2(dim),
    )
    return signs
|
|
|
|
def _triton_ternary_embed(indices, packed, e, vocab, dim, group_size):
    """Look up ternary embedding rows; returns float32 of shape ``(*indices.shape, dim)``."""
    idx_flat = indices.reshape(-1).contiguous().to(torch.int32)
    n_idx = idx_flat.shape[0]
    rows = torch.empty((n_idx, dim), device=indices.device, dtype=torch.float32)
    per_prog = 32
    _triton_ternary_embed_fwd_kernel[(triton.cdiv(n_idx, per_prog),)](
        idx_flat, packed, e, rows,
        n_idx, dim, vocab, ceil(dim / group_size), group_size,
        BLOCK_B=per_prog, BLOCK_D=triton.next_power_of_2(dim),
    )
    return rows.reshape(*indices.shape, dim)
|
|
|
|
class _TritonTernaryEmbedFn(torch.autograd.Function):
    """Autograd wrapper for the Triton ternary embedding lookup.

    Forward gathers rows of ``sign(T) * 2**E``. Backward returns no gradients.
    Instead it records a per-row gradient sign map on ``module``:

    * streaming mode (``module._stream_backward_updates``, default True, and the
      module has ``corr_accum`` / ``_accumulate_corr_from_grad_sign``): the sign
      map is folded into ``corr_accum`` immediately;
    * otherwise the sign map and the unpacked ternary matrix are stashed as
      ``_hook_grad_T_sign[_<component>]`` / ``_hook_T[_<component>]`` for a
      later update step.
    """

    @staticmethod
    def forward(ctx, indices, _dummy, module):
        # NOTE(review): `_dummy` looks like a placeholder input that exists only so
        # autograd calls backward; confirm with the caller.
        shape = tuple(module._T_shape.tolist())
        vocab, dim = shape
        packed = module.T_packed.contiguous()
        e = module.E.contiguous()
        ctx.save_for_backward(indices, packed, e)
        ctx.module = module
        ctx.shape = shape
        ctx.group_size = module.group_size
        # The component name is captured at forward time so backward can tag
        # the hooks with it.
        comp_name, _ = _COMPONENT_CONTEXT.get()
        ctx.comp_name = comp_name
        return _triton_ternary_embed(indices, packed, e, vocab, dim, module.group_size)

    @staticmethod
    def backward(ctx, grad_output):
        indices, packed, _ = ctx.saved_tensors
        vocab, dim = ctx.shape
        module = ctx.module
        comp_name = ctx.comp_name
        grad_2d = grad_output.reshape(-1, dim).contiguous()
        grad_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)

        has_corr = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
        if getattr(module, "_stream_backward_updates", True) and has_corr:
            # _accumulate_corr_from_grad_sign unpacks T itself, so no dense
            # ternary matrix is built here.
            module._accumulate_corr_from_grad_sign(grad_sign)
            module._streamed_bigint_backward = True
            return None, None, None

        # Deferred mode: stash the sign map plus the ternary matrix the forward
        # used (ctx.shape matches the saved `packed`).
        ternary = unpack_ternary(packed, ctx.shape, int(module._T_pad.item())).to(device=grad_2d.device)
        if comp_name is not None:
            setattr(module, f"_hook_grad_T_sign_{comp_name}", grad_sign)
            setattr(module, f"_hook_T_{comp_name}", ternary)
        else:
            module._hook_grad_T_sign = grad_sign
            module._hook_T = ternary
        return None, None, None
|
|
|
|
class _TritonTernaryLinearFn(torch.autograd.Function):
    """Autograd wrapper for the Triton ternary linear layer (no bias).

    Forward computes ``x @ W_eff.T`` in float32, where ``W_eff`` is built in the
    kernel from the packed trits, ``E``, and the ``corr_accum`` /
    ``step_counter`` correction state. Backward returns the input gradient
    computed against the *same* correction state the forward used, then either:

    * streams the weight update into ``module.corr_accum`` /
      ``module.step_counter`` (``module._stream_backward_updates``, default
      True), or
    * stashes an int8 weight-gradient sign map as
      ``module._hook_grad_T_sign[_<component>]`` for a later update step.
    """

    @staticmethod
    def forward(ctx, x, module):
        shape = tuple(module._T_shape.tolist())
        n_out, k_in = shape
        x_2d = x.reshape(-1, k_in).contiguous()
        packed = module.T_packed.contiguous()
        e = module.E.contiguous()
        # Snapshot the correction state. If the module is used several times in
        # one graph, a later call's backward mutates corr_accum/step_counter
        # before this call's backward runs. grad_x must use the (corr, step)
        # pair that produced this forward output, not a mix of old and new.
        # Cloning on device also avoids the host sync of step_counter.item().
        corr = torch.clone(module.corr_accum, memory_format=torch.contiguous_format)
        step = torch.clone(module.step_counter, memory_format=torch.contiguous_format)
        ctx.save_for_backward(x_2d, packed, e, corr, step)
        ctx.x_shape = x.shape
        ctx.shape = shape
        ctx.group_size = module.group_size
        ctx.module = module
        comp_name, _ = _COMPONENT_CONTEXT.get()
        ctx.comp_name = comp_name
        out = _triton_ternary_forward(x_2d, packed, e, corr, step, n_out, k_in, module.group_size)
        return out.reshape(*x.shape[:-1], n_out)

    @staticmethod
    def backward(ctx, grad_output):
        x_2d, packed, e, corr, step = ctx.saved_tensors
        n_out, k_in = ctx.shape
        grad_2d = grad_output.reshape(-1, n_out).contiguous()
        grad_x = _triton_ternary_grad_x(
            grad_2d, packed, e, corr, step, x_2d.shape[0], n_out, k_in, ctx.group_size
        )
        with torch.no_grad():
            if getattr(ctx.module, "_stream_backward_updates", True):
                # NOTE(review): the component weight is read at backward time from
                # _COMPONENT_CONTEXT (thread-local in _ComponentContext). If
                # autograd runs this backward on a device worker thread, a
                # weight set by the caller's thread may not be visible; confirm.
                _, bwd_weight = _COMPONENT_CONTEXT.get()
                # Integer step: |weight| rounded, at least 1, sign preserved.
                corr_step = max(1, int(round(abs(float(bwd_weight)))))
                if bwd_weight < 0:
                    corr_step = -corr_step
                _triton_accumulate_corr_direct(
                    packed, grad_2d, x_2d, ctx.module.corr_accum,
                    n_out, k_in, ctx.group_size, corr_step=corr_step,
                )
                ctx.module.step_counter.add_(abs(corr_step))
                ctx.module._streamed_bigint_backward = True
            else:
                grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in)
                comp_name = ctx.comp_name
                if comp_name is not None:
                    setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
                else:
                    ctx.module._hook_grad_T_sign = grad_sign.detach()
        return grad_x.reshape(*ctx.x_shape), None
|
|
|
|
| class _BigIntTernaryLinearFn(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x, module): |
| shape = tuple(module._T_shape.tolist()) |
| n_out, k_in = shape |
| x_2d = x.reshape(-1, k_in).contiguous() |
| ctx.module = module |
| ctx.x_shape = x.shape |
| ctx.shape = shape |
| ctx.x_dtype = x.dtype |
| ctx.save_for_backward(x_2d) |
| with torch.no_grad(): |
| w_eff = module.dequantize().to(device=x.device, dtype=torch.float32) |
| out = F.linear(x_2d.float(), w_eff, module.bias.float() if module.bias is not None else None) |
| return out.reshape(*x.shape[:-1], n_out) |
|
|
| @staticmethod |
| def backward(ctx, grad_output): |
| (x_2d,) = ctx.saved_tensors |
| module = ctx.module |
| n_out, k_in = ctx.shape |
| grad_2d = grad_output.reshape(-1, n_out).contiguous() |
| with torch.no_grad(): |
| w_eff = module.dequantize().to(device=grad_2d.device, dtype=torch.float32) |
| grad_x = grad_2d.float() @ w_eff |
| grad_sign = (grad_2d.float().transpose(0, 1) @ x_2d.float()).sign().to(torch.int8) |
| module._accumulate_corr_from_grad_sign(grad_sign) |
| module._streamed_bigint_backward = True |
| return grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype), None |
|
|
|
|
| """ |
| Log-Space Group Scale Representation |
| |
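Convention (the "Option B" layout chosen for this module):
S = 2^E, where S is a group's scale and E is its int8 log2 exponent
W_eff = T * 2^E
(TernaryScaleTensor additionally adds a correction derived from corr_accum and
step_counter to E before exponentiating; see TernaryScaleTensor._get_S.)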
| |
| Key log-space properties exploited: |
| Multiplication → addition: S1 * S2 = 2^(E1 + E2) |
| Division → subtraction: S1 / S2 = 2^(E1 - E2) |
| Dequant → integer shift: 2^E * T = T << E (for E >= 0) |
| |
| No IEEE floats in persistent state. E is stored as int8. |
| Ephemeral float only exists in autograd's computation graph. |
| """ |
|
|
class TScaleType(IntEnum):
    """Supported scale-group sizes; each member's integer value is its group size."""

    T4 = 4
    T6 = 6
    T8 = 8
    T16 = 16
    T32 = 32
    T64 = 64
    T96 = 96


# Number of consecutive input columns sharing one int8 exponent, per scale type.
GROUP_SIZES = {member: int(member) for member in TScaleType}
TILE_SIZE = 384
|
|
| def _n_groups(shape, group_size): |
| out_dim, in_dim = shape |
| return out_dim * ceil(in_dim / group_size) |
|
|
| def _expand_E(E, shape, group_size): |
| out_dim, in_dim = shape |
| gpr = ceil(in_dim / group_size) |
| E_2d = E.view(out_dim, gpr) |
| E_exp = E_2d.repeat_interleave(group_size, dim=1) |
| if E_exp.shape[1] > in_dim: |
| E_exp = E_exp[:, :in_dim] |
| return E_exp |
|
|
| def _ternarize(x, threshold=0.05): |
| return x.sign() * (x.abs() > threshold).to(x.dtype) |
|
|
|
|
| def _scaled_init_threshold(threshold: float, init_std: float) -> float: |
| if init_std <= 0: |
| return threshold |
| return min(float(threshold), 0.5 * float(init_std)) |
|
|
class TernaryScaleTensor(nn.Module):
    """Linear layer with packed ternary weights and int8 log2 group scales.

    Persistent state (all integer buffers):
      * ``T_packed`` / ``_T_shape`` / ``_T_pad``: the ``(out_dim, in_dim)``
        ternary matrix as produced by ``pack_ternary``.
      * ``E``: one int8 exponent per group of ``group_size`` consecutive input
        columns of each row (``out_dim * ceil(in_dim / group_size)`` entries).
      * ``corr_accum`` (int64, same size as ``E``) and ``step_counter`` (int64,
        one element): correction state folded into the exponent by ``_get_S``.
      * ``bias`` (int32) when ``bias=True``.

    The dequantized weight is ``T * 2**E_adj`` with
    ``E_adj = E + strength * corr_accum / max(step_counter * group_size, 1)``
    once ``step_counter > 0``, where ``strength`` comes from
    ``ARB_BIGINT_CORR_STRENGTH``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        threshold: float = 0.05,
        weight_init_std: float | None = None,
        tscale_type: TScaleType = TScaleType.T32,
        bias: bool = False,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Default init std: 1/sqrt(in_dim), capped at 0.1.
        init_std = min(0.1, in_dim ** -0.5) if weight_init_std is None else float(weight_init_std)
        init_threshold = _scaled_init_threshold(threshold, init_std)
        self.threshold = init_threshold
        self.tscale_type = tscale_type
        self.group_size = GROUP_SIZES[tscale_type]

        # Gaussian init -> ternary signs -> packed storage.
        w_init = torch.randn(out_dim, in_dim) * init_std
        packed_T, _, T_pad = pack_ternary(_ternarize(w_init, init_threshold))
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # Exponents come from the float init's group magnitudes, so T * 2**E
        # starts near the scale of w_init.
        self.register_buffer("E", self._group_log2_means(w_init.abs(), self.group_size))
        self.register_buffer("corr_accum", torch.zeros_like(self.E, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

        if bias:
            self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
        else:
            self.bias = None

    @staticmethod
    def _group_log2_means(abs_w, group_size):
        """Return flattened int8 ``log2`` of each column group's mean magnitude.

        ``abs_w`` is a non-negative ``(out_dim, in_dim)`` tensor. Columns are
        zero-padded to a multiple of ``group_size`` (padding counts toward the
        mean). All-zero groups map to exponent 0. The log is clamped to the
        int8 range and truncated toward zero by the int8 cast. The result is in
        row-major group order.
        """
        out_dim, in_dim = abs_w.shape
        gpr = ceil(in_dim / group_size)
        padded = abs_w.new_zeros(out_dim, gpr * group_size)
        padded[:, :in_dim] = abs_w
        grp_means = padded.view(out_dim, gpr, group_size).mean(dim=2)
        E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
        return E_vals.log2().clamp(-128, 127).to(torch.int8).flatten()

    @staticmethod
    def _grad_ready_input(x):
        """Return ``x``, or a detached copy requiring grad when autograd is on.

        Autograd only invokes a custom Function's backward (which performs the
        corr_accum bookkeeping here) if some input requires grad.
        """
        if torch.is_grad_enabled() and not x.requires_grad:
            return x.detach().requires_grad_(True)
        return x

    def _get_T(self):
        """Unpack the ternary matrix to shape ``(out_dim, in_dim)``."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_S(self):
        """Return the dense float scale matrix ``2**E_adj`` of shape ``(out_dim, in_dim)``."""
        e_adj = self.E.float()
        if hasattr(self, "corr_accum") and hasattr(self, "step_counter"):
            step = int(self.step_counter.item())
            if step > 0:
                denom = max(step * self.group_size, 1)
                e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
        E_exp = _expand_E(e_adj, (self.out_dim, self.in_dim), self.group_size)
        return torch.exp2(E_exp)

    def _ensure_group_lr(self):
        """Create (or re-create on shape/device mismatch) an int8 ``group_lr`` buffer of ones."""
        if not hasattr(self, "group_lr"):
            self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
        elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device:
            self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
        return self.group_lr

    def precompile_kernels(self, M: int):
        """No-op here."""
        pass

    def forward(self, x):
        """Apply the ternary linear layer, choosing a backend from ``ARB_TERNARY_BACKEND``.

        Order: TileLang (only when explicitly requested), then Triton for CUDA
        inputs, then the pure-PyTorch ``_BigIntTernaryLinearFn``.

        Raises:
            RuntimeError: TileLang requested while grad is enabled without
                TileLang training enabled, or Triton explicitly requested but
                not usable for this input.
        """
        backend = _backend_preference()
        if backend == "tilelang" and _HAS_TILELANG:
            if torch.is_grad_enabled() and not _tilelang_training_enabled():
                raise RuntimeError(
                    "ARB_TERNARY_BACKEND='tilelang' is inference-only by default. "
                    "BigInt ternary training should use ARB_TERNARY_BACKEND='triton'. "
                    "Set ARB_TILELANG_TRAINING=1 only for experimental TileLang training."
                )
            x_for_grad = self._grad_ready_input(x)
            N, K = tuple(self._T_shape.tolist())
            x_2d = x_for_grad.reshape(-1, K)
            M = x_2d.shape[0]
            try:
                fwd_kernel = _get_kernel(M, N, K, self.group_size, "fwd")
                y = _TernaryLinearFn.apply(x_for_grad, self, fwd_kernel)
                if self.bias is not None:
                    y = y + self.bias.float()
                return y
            except Exception as e:
                # Best-effort: warn and fall back rather than fail the step.
                warnings.warn(
                    f"TileLang forward failed for {self._T_shape.tolist()}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                # Re-run normal auto selection: Triton when the input is on CUDA,
                # torch otherwise. Forcing "triton" here used to raise a
                # misleading "Triton requested" error for non-CUDA inputs.
                backend = "auto"
        if x.is_cuda and _HAS_TRITON and backend in {"auto", "triton"}:
            y = _TritonTernaryLinearFn.apply(self._grad_ready_input(x), self)
            if self.bias is not None:
                y = y + self.bias.float()
            return y
        if backend == "triton":
            raise RuntimeError("ARB_TERNARY_BACKEND='triton' requested, but Triton is unavailable for this input.")
        return _BigIntTernaryLinearFn.apply(self._grad_ready_input(x), self)

    @torch.no_grad()
    def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
        """Fold a ``(out_dim, in_dim)`` gradient-sign map into ``corr_accum``.

        For each group, ``score = sum(grad_sign * T)``.
        ``corr_accum -= score * corr_step`` and ``step_counter += |corr_step|``.
        A ``grad_sign`` with the wrong shape is ignored.
        """
        shape = tuple(self._T_shape.tolist())
        out_dim, in_dim = shape
        if tuple(grad_sign.shape) != shape:
            return
        T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
        signed = grad_sign.to(torch.int16) * T
        gpr = ceil(in_dim / self.group_size)
        total_in = gpr * self.group_size
        if total_in > in_dim:
            # Zero-pad the partial last group so the view below is valid.
            signed = F.pad(signed, (0, total_in - in_dim))
        score = signed.view(out_dim, gpr, self.group_size).sum(dim=2, dtype=torch.int16)
        self.corr_accum -= score.flatten().to(device=self.corr_accum.device, dtype=torch.int64) * int(corr_step)
        self.step_counter += abs(int(corr_step))

    def ternary_step(self, lr=1, accum_threshold=None):
        """Consume a stashed ``_hook_grad_T_sign`` into ``corr_accum``; ``lr`` / ``accum_threshold`` are unused."""
        self._had_flip = False
        if hasattr(self, "_hook_grad_T_sign"):
            self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign

    def update_E(self, lr=1, loss_signal=None):
        """Consume stashed gradient hooks into ``corr_accum`` and drop them.

        Prefers a dense ``_hook_grad_T_sign``. Otherwise it recomputes the sign
        of ``grad.T @ x`` from ``_hook_grad_2d`` / ``_hook_x_2d``.
        """
        has_dense_grad = hasattr(self, "_hook_grad_T_sign")
        has_direct_grad = hasattr(self, "_hook_grad_2d") and hasattr(self, "_hook_x_2d")
        if not has_dense_grad and not has_direct_grad:
            return
        if has_dense_grad:
            self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign
        else:
            grad = self._hook_grad_2d.to(device=self.E.device, dtype=torch.float32)
            x = self._hook_x_2d.to(device=self.E.device, dtype=torch.float32)
            grad_sign = (grad.transpose(0, 1) @ x).sign().to(torch.int8)
            self._accumulate_corr_from_grad_sign(grad_sign)
            del self._hook_grad_2d
            del self._hook_x_2d
        if hasattr(self, "_hook_T"):
            del self._hook_T

    @property
    def effective_bpw(self) -> float:
        """Stored bits per weight.

        Counts 8/5 bits per trit (5 per byte), 8 per exponent, 64 per
        correction counter, and 32 per bias entry.
        """
        group_size = self.group_size
        total = self._T_shape[0].item() * self._T_shape[1].item()
        n_grp = _n_groups(tuple(self._T_shape.tolist()), group_size)
        sign_bits = total * (8 / 5)
        scale_bits = n_grp * 8.0
        corr_bits = n_grp * 64.0
        bias_bits = self.bias.numel() * 32.0 if self.bias is not None else 0.0
        return (sign_bits + scale_bits + corr_bits + bias_bits) / total

    def dequantize(self) -> torch.Tensor:
        """Return the dense float weight ``S * T`` (including the correction term)."""
        T = self._get_T().float()
        S = self._get_S()
        return S * T

    def tscale_to(self, tscale_type: TScaleType):
        """Switch to another scale-group size; returns ``self``.

        If the group size or group count changes, ``E`` is re-derived from the
        current dequantized weights (old ``E`` plus correction), so the
        effective scale is approximately preserved. ``corr_accum`` and
        ``step_counter`` are then reset.
        """
        new_group_size = GROUP_SIZES[tscale_type]
        out_dim, in_dim = tuple(self._T_shape.tolist())
        new_n_grp = out_dim * ceil(in_dim / new_group_size)
        regroup = new_group_size != self.group_size or self.E.shape[0] != new_n_grp
        # Dequantize BEFORE switching group_size: _get_S expands E using the
        # current group_size, which must still match E's layout.
        w_abs = self.dequantize().abs() if regroup else None
        self.tscale_type = tscale_type
        self.group_size = new_group_size
        if regroup:
            self.E = self._group_log2_means(w_abs, new_group_size)
            self.corr_accum = torch.zeros_like(self.E, dtype=torch.int64)
            self.step_counter = torch.zeros(1, dtype=torch.int64, device=self.E.device)
        return self

    tscale_cast = tscale_to

    def extra_repr(self) -> str:
        return (
            f"in_dim={self.in_dim}, out_dim={self.out_dim}, "
            f"tscale_type={self.tscale_type.name}, group_size={self.group_size}, "
            f"effective_bpw={self.effective_bpw:.2f}"
        )
|
|
if _HAS_TRITON:

    # Fused ternary RMSNorm forward. Each program handles BLOCK_B rows, and one
    # BLOCK_D tile spans the whole row, so BLOCK_D must be >= DIM (presumably
    # the reason for the ARB_RMSNORM_TRITON_MAX_DIM cap).
    #   out = x / sqrt(mean(x^2) + eps) * w,   w = sign(T) * 2**E
    # The weight is a single (1, DIM) ternary row packed 5 trits per byte,
    # with one exponent per GROUP_SIZE columns.
    @triton.jit
    def _triton_rmsnorm_fwd_kernel(
        x_ptr, packed_ptr, e_ptr, out_ptr, eps,
        BATCH: tl.constexpr, DIM: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
    ):
        pid_b = tl.program_id(0)
        offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
        offs_d = tl.arange(0, BLOCK_D)

        x = tl.load(
            x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
            other=0.0,
        )
        sq = x * x
        msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
        rms = tl.sqrt(msq + eps)
        x_norm = x / rms

        # Decode the packed ternary weight row.
        pack_idx = offs_d // 5
        trit_pos = offs_d - pack_idx * 5
        packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32)
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
            tl.where(trit_pos == 2, 9,
            tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        sign = trit.to(tl.int32) - 1

        e_idx = offs_d // GROUP_SIZE
        e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32)
        w = sign.to(tl.float32) * tl.exp2(e_val)
        w = tl.where(offs_d < DIM, w, 0.0)

        out = x_norm * w[None, :]
        tl.store(
            out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            out,
            mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
        )

    # Input gradient of the fused ternary RMSNorm (the weight gets no gradient):
    #   dx = (dy*w - x_norm * mean(x_norm * dy*w)) / rms
    # rms is recomputed with the same eps as the forward.
    @triton.jit
    def _triton_rmsnorm_bwd_kernel(
        grad_out_ptr, x_ptr, packed_ptr, e_ptr,
        grad_x_ptr, eps,
        BATCH: tl.constexpr, DIM: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
    ):
        pid_b = tl.program_id(0)
        offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
        offs_d = tl.arange(0, BLOCK_D)

        x = tl.load(
            x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
            other=0.0,
        )
        sq = x * x
        msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
        rms = tl.sqrt(msq + eps)
        x_norm = x / rms

        pack_idx = offs_d // 5
        trit_pos = offs_d - pack_idx * 5
        packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32)
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
            tl.where(trit_pos == 2, 9,
            tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        sign = trit.to(tl.int32) - 1

        e_idx = offs_d // GROUP_SIZE
        e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32)
        w = sign.to(tl.float32) * tl.exp2(e_val)
        w = tl.where(offs_d < DIM, w, 0.0)

        dy = tl.load(
            grad_out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
            other=0.0,
        )
        dyw = dy * w[None, :]

        c1 = tl.sum(x_norm * dyw, axis=1, keep_dims=True) / DIM
        dx = (dyw - x_norm * c1) / rms

        tl.store(
            grad_x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            dx,
            mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM),
        )

    class _TritonRMSNormFn(torch.autograd.Function):
        """Autograd wrapper for the fused ternary RMSNorm kernels.

        Uses ``module.eps`` (defaulting to 1e-5 if absent). Previously the
        kernels hard-coded 1e-5 and silently ignored a custom eps that the
        torch fallback path honors. Only ``x`` receives a gradient.
        """

        @staticmethod
        def forward(ctx, x, module, packed, e, dim, group_size):
            ctx.module = module
            eps = float(getattr(module, "eps", 1e-5))
            x_2d = x.reshape(-1, dim).contiguous()
            batch = x_2d.shape[0]
            out = torch.empty_like(x_2d)
            block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _triton_rmsnorm_fwd_kernel[grid](
                x_2d, packed, e, out, eps,
                batch, dim, ceil(dim / group_size), group_size,
                BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
            )
            ctx.save_for_backward(x_2d, packed, e)
            ctx.dim = dim
            ctx.group_size = group_size
            ctx.eps = eps
            comp_name, _ = _COMPONENT_CONTEXT.get()
            ctx.comp_name = comp_name
            return out.reshape(*x.shape)

        @staticmethod
        def backward(ctx, grad_output):
            x_2d, packed, e = ctx.saved_tensors
            dim = ctx.dim
            group_size = ctx.group_size
            grad_2d = grad_output.reshape(-1, dim).contiguous()
            batch = grad_2d.shape[0]
            grad_x = torch.empty_like(x_2d)
            block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _triton_rmsnorm_bwd_kernel[grid](
                grad_2d, x_2d, packed, e, grad_x, ctx.eps,
                batch, dim, ceil(dim / group_size), group_size,
                BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
            )
            return grad_x.reshape(*grad_output.shape), None, None, None, None, None
|
|
|
|
class TernaryRMSNorm(nn.Module):
    """RMSNorm whose per-channel weight is a ternary row with int8 log2 group scales.

    The weight is initialised from ``ones(1, dim)``: every trit is +1 (unless
    ``threshold >= 1``, which zeroes it) and every exponent is ``log2(1) = 0``,
    so the initial weight is exactly 1. ``ternary_step`` and ``update_E`` are
    no-ops, so nothing here changes it afterwards.
    """

    def __init__(self, dim, eps=1e-5, threshold=0.05, tscale_type=TScaleType.T64):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.threshold = threshold
        self.tscale_type = tscale_type
        self.group_size = GROUP_SIZES[tscale_type]

        ones = torch.ones(1, dim)
        packed, _, pad = pack_ternary(_ternarize(ones, threshold))
        self.register_buffer("T_packed", packed)
        self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(pad, dtype=torch.long))

        # Per-group log2 of the mean init magnitude (zero-padded tail included).
        groups = ceil(dim / self.group_size)
        buf = torch.zeros(1, groups * self.group_size)
        buf[:, :dim] = ones.abs()
        means = buf.view(1, groups, self.group_size).mean(dim=2)
        safe_means = torch.where(means > 0, means, torch.ones_like(means))
        self.register_buffer("E", safe_means.flatten().log2().clamp(-128, 127).to(torch.int8))
        self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
        self.register_buffer("T_accum", torch.zeros(1, dim, dtype=torch.int8))

    def _ensure_E_accum(self):
        """Return ``E_accum``, (re)creating it as int8 zeros if missing or mismatched with ``E``."""
        if hasattr(self, "E_accum"):
            stale = self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device
            if stale:
                self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
        else:
            self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        return self.E_accum

    def _ensure_group_lr(self):
        """Return ``group_lr``, (re)creating it as int8 ones if missing or mismatched with ``E``."""
        if hasattr(self, "group_lr"):
            stale = self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device
            if stale:
                self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
        else:
            self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
        return self.group_lr

    def _get_T(self):
        """Unpack the ternary weight row to shape ``(dim,)``."""
        shape = tuple(self._T_shape.tolist())
        return unpack_ternary(self.T_packed, shape, int(self._T_pad.item())).squeeze(0)

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply the ternary weight.

        CUDA inputs with Triton available and ``dim`` within
        ``ARB_RMSNORM_TRITON_MAX_DIM`` use the fused kernel.
        """
        if x.is_cuda and _HAS_TRITON and self.dim <= _rmsnorm_triton_max_dim():
            return _TritonRMSNormFn.apply(
                x, self, self.T_packed.contiguous(), self.E.contiguous(),
                self.dim, self.group_size,
            )

        normed = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        if x.is_cuda:
            # NOTE(review): on CUDA without the fused kernel the ternary weight is
            # NOT applied, unlike the Triton and CPU paths. This only matches
            # them while the weight is all ones (its initial value); confirm
            # this shortcut is intended.
            return normed

        ternary = self._get_T()
        exponents = _expand_E(self.E, tuple(self._T_shape.tolist()), self.group_size).squeeze(0)
        weight = torch.exp2(exponents.float()) * ternary.float()
        return weight * normed

    def ternary_step(self, lr=1, accum_threshold=3):
        """No-op: the norm weight is not trained here."""
        pass

    def update_E(self, lr=1, loss_signal=None):
        """No-op: the norm scales are not trained here."""
        pass

    def extra_repr(self):
        return f"dim={self.dim}, tscale_type={self.tscale_type.name}"
|
|