import os import threading import warnings import torch import torch.nn as nn import torch.nn.functional as F from enum import IntEnum from math import ceil from ..converters.convert_to_ternary8 import pack_ternary, unpack_ternary _HAS_TILELANG = False try: import tilelang import tilelang.language as T _HAS_TILELANG = True except ImportError: pass _HAS_TRITON = False try: import triton import triton.language as tl _HAS_TRITON = True except ImportError: pass def _backend_preference() -> str: backend = os.environ.get("ARB_TERNARY_BACKEND", "auto").strip().lower() if backend not in {"auto", "tilelang", "triton", "torch"}: warnings.warn( f"Unknown ARB_TERNARY_BACKEND={backend!r}; falling back to auto.", RuntimeWarning, stacklevel=2, ) return "auto" return backend def _rmsnorm_triton_max_dim() -> int: raw = os.environ.get("ARB_RMSNORM_TRITON_MAX_DIM", "4096").strip() try: return max(0, int(raw)) except ValueError: warnings.warn( f"Invalid ARB_RMSNORM_TRITON_MAX_DIM={raw!r}; using 4096.", RuntimeWarning, stacklevel=2, ) return 4096 def _bigint_corr_strength() -> float: raw = os.environ.get("ARB_BIGINT_CORR_STRENGTH", "4.0").strip() try: return float(raw) except ValueError: warnings.warn( f"Invalid ARB_BIGINT_CORR_STRENGTH={raw!r}; using 4.0.", RuntimeWarning, stacklevel=2, ) return 4.0 class _ComponentContext: _local = threading.local() @classmethod def get(cls): val = getattr(cls._local, "current", None) if val is None: return None, 1.0 return val @classmethod def set(cls, name, weight=1.0): if name is None: cls._local.current = None else: cls._local.current = (name, weight) @classmethod def clear(cls): cls._local.current = None _COMPONENT_CONTEXT = _ComponentContext def _tilelang_training_enabled() -> bool: return os.environ.get("ARB_TILELANG_TRAINING", "0").strip().lower() in {"1", "true", "yes"} if _HAS_TILELANG: tilelang_jit = tilelang.jit(pass_configs={"tl.disable_warp_specialized": True}) def _ternary_fwd_kernel( M: int, N: int, K: int, group_size: int = 12, corr_strength: float = 4.0, block_M: int = 64, block_N: int = 64, block_K: int = 32, threads: int = 128, num_stages: int = 2, ): gpr = (K + group_size - 1) // group_size cs = corr_strength @T.prim_func def kernel( x: T.Tensor((M, K), "float16"), T_packed: T.Tensor((N * K + 4) // 5, "uint8"), E: T.Tensor((N * gpr), "int8"), corr_accum: T.Tensor((N * gpr), "int64"), step_counter: T.Tensor((1,), "int64"), output: T.Tensor((M, N), "float32"), ): steps = T.cast(step_counter[0], "int32") with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): x_shared = T.alloc_shared((block_M, block_K), dtype="float16") dq_shared = T.alloc_shared((block_N, block_K), dtype="float16") acc = T.alloc_fragment((block_M, block_N), dtype="float32") T.use_swizzle(10) T.clear(acc) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(x[bx * block_M, k * block_K], x_shared) for i, j in T.Parallel(block_N, block_K): i_glob = by * block_N + i j_glob = k * block_K + j if i_glob < N and j_glob < K: lin_idx = i_glob * K + j_glob pack_idx = lin_idx // 5 trit_pos = lin_idx % 5 packed_val = T.cast(T_packed[pack_idx], "int32") trit = T.if_then_else( trit_pos == 0, packed_val % 3, T.if_then_else(trit_pos == 1, (packed_val // 3) % 3, T.if_then_else(trit_pos == 2, (packed_val // 9) % 3, T.if_then_else(trit_pos == 3, (packed_val // 27) % 3, (packed_val // 81) % 3)))) sign_val = T.cast(trit, "int32") - 1 exp_idx = i_glob * gpr + j_glob // group_size exp_val = T.cast(E[exp_idx], "int32") ca = T.cast(corr_accum[exp_idx], "int32") den = T.max(steps * group_size, 1) mc = T.cast(ca, "float32") / T.cast(den, "float32") e_adj = T.cast(exp_val, "float32") + mc * cs ecl = T.min(T.max(e_adj, -14.0), 15.0) dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16") T.gemm(x_shared, dq_shared, acc, transpose_B=True) T.copy(acc, output[bx * block_M, by * block_N]) return tilelang_jit(kernel) def _ternary_grad_x_kernel( M: int, N: int, K: int, group_size: int = 12, corr_strength: float = 4.0, block_M: int = 64, block_N: int = 64, block_K: int = 32, threads: int = 128, num_stages: int = 2, ): gpr = (K + group_size - 1) // group_size cs = corr_strength @T.prim_func def kernel( grad_y: T.Tensor((M, N), "float16"), T_packed: T.Tensor((N * K + 4) // 5, "uint8"), E: T.Tensor((N * gpr), "int8"), corr_accum: T.Tensor((N * gpr), "int64"), step_counter: T.Tensor((1,), "int64"), output: T.Tensor((M, K), "float32"), ): steps = T.cast(step_counter[0], "int32") with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=threads) as (bx, by): gy_shared = T.alloc_shared((block_M, block_N), dtype="float16") dq_shared = T.alloc_shared((block_N, block_K), dtype="float16") acc = T.alloc_fragment((block_M, block_K), dtype="float32") T.use_swizzle(10) T.clear(acc) for n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages): T.copy(grad_y[bx * block_M, n * block_N], gy_shared) for i, j in T.Parallel(block_N, block_K): i_glob = n * block_N + i j_glob = by * block_K + j if i_glob < N and j_glob < K: lin_idx = i_glob * K + j_glob pack_idx = lin_idx // 5 trit_pos = lin_idx % 5 packed_val = T.cast(T_packed[pack_idx], "int32") trit = T.if_then_else( trit_pos == 0, packed_val % 3, T.if_then_else(trit_pos == 1, (packed_val // 3) % 3, T.if_then_else(trit_pos == 2, (packed_val // 9) % 3, T.if_then_else(trit_pos == 3, (packed_val // 27) % 3, (packed_val // 81) % 3)))) sign_val = T.cast(trit, "int32") - 1 exp_idx = i_glob * gpr + j_glob // group_size exp_val = T.cast(E[exp_idx], "int32") ca = T.cast(corr_accum[exp_idx], "int32") den = T.max(steps * group_size, 1) mc = T.cast(ca, "float32") / T.cast(den, "float32") e_adj = T.cast(exp_val, "float32") + mc * cs ecl = T.min(T.max(e_adj, -14.0), 15.0) dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16") T.gemm(gy_shared, dq_shared, acc) T.copy(acc, output[bx * block_M, by * block_K]) return tilelang_jit(kernel) _KERNEL_CACHE_FWD = {} _KERNEL_CACHE_GX = {} def _get_kernel(M, N, K, group_size, mode, corr_strength=4.0): cs = corr_strength if mode == "fwd": cache = _KERNEL_CACHE_FWD key = (M, N, K, group_size, cs) if key not in cache: cache[key] = _ternary_fwd_kernel(M, N, K, group_size, corr_strength=cs) return cache[key] elif mode == "grad_x": cache = _KERNEL_CACHE_GX key = (M, N, K, group_size) if key not in cache: cache[key] = _ternary_grad_x_kernel(M, N, K, group_size) return cache[key] raise ValueError(f"Unknown TileLang kernel mode: {mode}") def _get_grad_kernels(M, N, K, group_size): return _get_kernel(M, N, K, group_size, "grad_x") class _TernaryLinearFn(torch.autograd.Function): @staticmethod def forward(ctx, x, module, fwd_kernel): ctx.module = module T_packed = module.T_packed E = module.E shape = tuple(module._T_shape.tolist()) N, K = shape x_2d = x.reshape(-1, K).contiguous() ctx.group_size = module.group_size ctx.shape = shape ctx.x_shape = x.shape comp_name, _ = _COMPONENT_CONTEXT.get() ctx.comp_name = comp_name ctx.x_dtype = x.dtype has_corr = hasattr(module, "corr_accum") and hasattr(module, "step_counter") ctx.save_for_backward(x_2d, T_packed, E) ctx.has_corr = has_corr ctx.step_snapshot = int(module.step_counter.item()) if has_corr else 0 with torch.no_grad(): M = x_2d.shape[0] output = torch.empty(M, N, device=x.device, dtype=torch.float32) if has_corr: fwd_kernel(x_2d.half(), T_packed, E, module.corr_accum.contiguous(), module.step_counter.contiguous(), output) else: fwd_kernel(x_2d.half(), T_packed, E, torch.zeros(N * ((K + module.group_size - 1) // module.group_size), dtype=torch.int64, device=x.device), torch.zeros(1, dtype=torch.int64, device=x.device), output) return output.reshape(*x.shape[:-1], N) @staticmethod def backward(ctx, grad_output): x_2d, T_packed, E = ctx.saved_tensors group_size = ctx.group_size N, K = ctx.shape M = x_2d.shape[0] grad_2d = grad_output.reshape(-1, N).contiguous() if ctx.has_corr: corr_accum = ctx.module.corr_accum.contiguous() step_counter = torch.tensor([ctx.step_snapshot], dtype=torch.int64, device=x_2d.device) else: corr_accum = torch.zeros(N * ((K + group_size - 1) // group_size), dtype=torch.int64, device=x_2d.device) step_counter = torch.zeros(1, dtype=torch.int64, device=x_2d.device) grad_x_kernel = _get_grad_kernels(M, N, K, group_size) with torch.no_grad(): grad_x = torch.empty(M, K, device=x_2d.device, dtype=torch.float32) grad_x_kernel(grad_2d.half(), T_packed, E, corr_accum, step_counter, grad_x) comp_name = ctx.comp_name if _HAS_TRITON and ctx.has_corr and getattr(ctx.module, "_stream_backward_updates", True): bwd_name, bwd_weight = _COMPONENT_CONTEXT.get() if bwd_name is None: bwd_weight = 1.0 base_step = int(getattr(ctx.module, "_backward_t_accum_step", 1)) corr_step = max(1, int(round(abs(float(bwd_weight)) * base_step))) if bwd_weight < 0: corr_step = -corr_step _triton_accumulate_corr_direct( T_packed, grad_2d, x_2d, ctx.module.corr_accum, N, K, group_size, corr_step=corr_step, ) ctx.module.step_counter.add_(abs(corr_step)) ctx.module._streamed_bigint_backward = True elif _HAS_TRITON: grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, N, K) if comp_name is not None: setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach()) else: ctx.module._hook_grad_T_sign = grad_sign.detach() elif comp_name is not None: setattr(ctx.module, f"_hook_grad_2d_{comp_name}", grad_2d.detach()) setattr(ctx.module, f"_hook_x_2d_{comp_name}", x_2d.detach()) else: ctx.module._hook_grad_2d = grad_2d.detach() ctx.module._hook_x_2d = x_2d.detach() grad_x_reshaped = grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype) return grad_x_reshaped, None, None if _HAS_TRITON: @triton.jit def _triton_ternary_fwd_kernel( x_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, CORR_STRENGTH: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k0 in range(0, K, BLOCK_K): k = k0 + offs_k x = tl.load( x_ptr + offs_m[:, None] * K + k[None, :], mask=(offs_m[:, None] < M) & (k[None, :] < K), other=0.0, ) lin = offs_n[:, None] * K + k[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load( packed_ptr + pack_idx, mask=(offs_n[:, None] < N) & (k[None, :] < K), other=0, ).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 sign = trit.to(tl.int32) - 1 e_idx = offs_n[:, None] * GPR + k[None, :] // GROUP_SIZE e_val = tl.load( e_ptr + e_idx, mask=(offs_n[:, None] < N) & (k[None, :] < K), other=0, ).to(tl.float32) corr_val = tl.load( corr_ptr + e_idx, mask=(offs_n[:, None] < N) & (k[None, :] < K), other=0, ).to(tl.float32) step_val = tl.load(step_ptr).to(tl.float32) denom = tl.maximum(step_val * GROUP_SIZE, 1.0) e_adj = e_val + (corr_val / denom) * CORR_STRENGTH w = sign.to(tl.float32) * tl.exp2(e_adj) w = tl.where((offs_n[:, None] < N) & (k[None, :] < K), w, 0.0) acc += tl.dot(x, tl.trans(w)) tl.store( out_ptr + offs_m[:, None] * N + offs_n[None, :], acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), ) @triton.jit def _triton_ternary_grad_x_kernel( grad_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, CORR_STRENGTH: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_m = tl.program_id(0) pid_k = tl.program_id(1) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) offs_n = tl.arange(0, BLOCK_N) acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) for n0 in range(0, N, BLOCK_N): n = n0 + offs_n grad = tl.load( grad_ptr + offs_m[:, None] * N + n[None, :], mask=(offs_m[:, None] < M) & (n[None, :] < N), other=0.0, ) lin = n[:, None] * K + offs_k[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load( packed_ptr + pack_idx, mask=(n[:, None] < N) & (offs_k[None, :] < K), other=0, ).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 sign = trit.to(tl.int32) - 1 e_idx = n[:, None] * GPR + offs_k[None, :] // GROUP_SIZE e_val = tl.load( e_ptr + e_idx, mask=(n[:, None] < N) & (offs_k[None, :] < K), other=0, ).to(tl.float32) corr_val = tl.load( corr_ptr + e_idx, mask=(n[:, None] < N) & (offs_k[None, :] < K), other=0, ).to(tl.float32) step_val = tl.load(step_ptr).to(tl.float32) denom = tl.maximum(step_val * GROUP_SIZE, 1.0) e_adj = e_val + (corr_val / denom) * CORR_STRENGTH w = sign.to(tl.float32) * tl.exp2(e_adj) w = tl.where((n[:, None] < N) & (offs_k[None, :] < K), w, 0.0) acc += tl.dot(grad, w) tl.store( out_ptr + offs_m[:, None] * K + offs_k[None, :], acc, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), ) @triton.jit def _triton_ternary_grad_sign_kernel( grad_ptr, x_ptr, sign_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_n = tl.program_id(0) pid_k = tl.program_id(1) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) offs_m = tl.arange(0, BLOCK_M) acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32) for m0 in range(0, M, BLOCK_M): m = m0 + offs_m grad = tl.load( grad_ptr + m[:, None] * N + offs_n[None, :], mask=(m[:, None] < M) & (offs_n[None, :] < N), other=0.0, ) x = tl.load( x_ptr + m[:, None] * K + offs_k[None, :], mask=(m[:, None] < M) & (offs_k[None, :] < K), other=0.0, ) acc += tl.dot(tl.trans(grad), x, input_precision="ieee") sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)) tl.store( sign_ptr + offs_n[:, None] * K + offs_k[None, :], sign.to(tl.int8), mask=(offs_n[:, None] < N) & (offs_k[None, :] < K), ) @triton.jit def _triton_update_e_kernel( packed_ptr, grad_sign_ptr, e_ptr, e_accum_ptr, N: tl.constexpr, K: tl.constexpr, GROUP_SIZE: tl.constexpr, GPR: tl.constexpr, E_ACCUM_THRESHOLD: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_G: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_n = tl.program_id(0) pid_g = tl.program_id(1) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_g = pid_g * BLOCK_G + tl.arange(0, BLOCK_G) offs_r = tl.arange(0, BLOCK_K) k = offs_g[:, None] * GROUP_SIZE + offs_r[None, :] valid_group = offs_g < GPR lin = offs_n[:, None, None] * K + k[None, :, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load( packed_ptr + pack_idx, mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K), other=0, ).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 ternary = trit.to(tl.int32) - 1 grad_sign = tl.load( grad_sign_ptr + offs_n[:, None, None] * K + k[None, :, :], mask=(offs_n[:, None, None] < N) & valid_group[None, :, None] & (offs_r[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K), other=0, ).to(tl.int32) contrib = grad_sign * ternary score = tl.sum(contrib, axis=2) delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0)) e_idx = offs_n[:, None] * GPR + offs_g[None, :] old_accum = tl.load( e_accum_ptr + e_idx, mask=(offs_n[:, None] < N) & valid_group[None, :], other=0, ).to(tl.int32) new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta)) step_up = new_accum >= E_ACCUM_THRESHOLD step_down = new_accum <= -E_ACCUM_THRESHOLD e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0)) stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD old_e = tl.load( e_ptr + e_idx, mask=(offs_n[:, None] < N) & valid_group[None, :], other=0, ).to(tl.int32) new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step)) tl.store( e_ptr + e_idx, new_e.to(tl.int8), mask=(offs_n[:, None] < N) & valid_group[None, :], ) tl.store( e_accum_ptr + e_idx, stored_accum.to(tl.int8), mask=(offs_n[:, None] < N) & valid_group[None, :], ) @triton.jit def _triton_ternary_step_kernel( packed_ptr, grad_sign_ptr, accum_ptr, per_group_threshold_ptr, TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr, T_ACCUM_STEP: tl.constexpr, K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, HAS_PER_GROUP_THRESHOLD: tl.constexpr, BLOCK_T: tl.constexpr, ): pack_idx = tl.program_id(0) offs_t = tl.arange(0, BLOCK_T) valid_trit = offs_t < 5 lin = pack_idx * 5 + offs_t valid = valid_trit & (lin < TOTAL) old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32) divisor = tl.where( offs_t == 0, 1, tl.where(offs_t == 1, 3, tl.where(offs_t == 2, 9, tl.where(offs_t == 3, 27, 81))), ) old_code = (old_packed // divisor) % 3 old_sign = old_code.to(tl.int32) - 1 grad_sign = tl.load(grad_sign_ptr + lin, mask=valid, other=0).to(tl.int32) old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32) new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP)) if HAS_PER_GROUP_THRESHOLD: n = lin // K k = lin - n * K g_idx = n * GPR + k // GROUP_SIZE threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32) else: threshold = ACCUM_THRESHOLD flip_up = new_accum > threshold flip_down = new_accum < -threshold did_flip = valid & (flip_up | flip_down) new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign)) stored_accum = tl.where(did_flip, 0, new_accum) tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid) new_code = tl.where(valid, new_sign + 1, 0) packed_val = tl.sum(new_code * divisor, axis=0) tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8)) @triton.jit def _triton_update_e_direct_kernel( packed_ptr, grad_ptr, x_ptr, e_ptr, e_accum_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, GROUP_SIZE: tl.constexpr, GPR: tl.constexpr, E_ACCUM_THRESHOLD: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_n = tl.program_id(0) pid_g = tl.program_id(1) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_r = tl.arange(0, BLOCK_K) k = pid_g * GROUP_SIZE + offs_r offs_m = tl.arange(0, BLOCK_M) acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32) for m0 in range(0, M, BLOCK_M): m = m0 + offs_m grad = tl.load( grad_ptr + m[:, None] * N + offs_n[None, :], mask=(m[:, None] < M) & (offs_n[None, :] < N), other=0.0, ) x = tl.load( x_ptr + m[:, None] * K + k[None, :], mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), other=0.0, ) acc += tl.dot(tl.trans(grad), x, input_precision="ieee") grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32) lin = offs_n[:, None] * K + k[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load( packed_ptr + pack_idx, mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), other=0, ).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 ternary = trit.to(tl.int32) - 1 contrib = tl.where( (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), grad_sign * ternary, 0, ) score = tl.sum(contrib, axis=1) delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0)) e_idx = offs_n * GPR + pid_g old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32) new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta)) step_up = new_accum >= E_ACCUM_THRESHOLD step_down = new_accum <= -E_ACCUM_THRESHOLD e_step = tl.where(step_up, 1, tl.where(step_down, -1, 0)) stored_accum = new_accum - e_step * E_ACCUM_THRESHOLD old_e = tl.load(e_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32) new_e = tl.minimum(127, tl.maximum(-128, old_e + e_step)) tl.store(e_ptr + e_idx, new_e.to(tl.int8), mask=offs_n < N) tl.store(e_accum_ptr + e_idx, stored_accum.to(tl.int8), mask=offs_n < N) @triton.jit def _triton_ternary_step_direct_kernel( packed_ptr, grad_ptr, x_ptr, accum_ptr, per_group_threshold_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr, T_ACCUM_STEP: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, HAS_PER_GROUP_THRESHOLD: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr, ): pack_idx = tl.program_id(0) offs_t = tl.arange(0, BLOCK_T) lin = pack_idx * 5 + offs_t valid_trit = offs_t < 5 valid = valid_trit & (lin < TOTAL) n = lin // K k = lin - n * K offs_m = tl.arange(0, BLOCK_M) acc = tl.zeros((BLOCK_T,), dtype=tl.float32) for m0 in range(0, M, BLOCK_M): m = m0 + offs_m grad = tl.load( grad_ptr + m[:, None] * N + n[None, :], mask=(m[:, None] < M) & valid[None, :], other=0.0, ) x = tl.load( x_ptr + m[:, None] * K + k[None, :], mask=(m[:, None] < M) & valid[None, :], other=0.0, ) acc += tl.sum(grad * x, axis=0) grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32) old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32) divisor = tl.where( offs_t == 0, 1, tl.where(offs_t == 1, 3, tl.where(offs_t == 2, 9, tl.where(offs_t == 3, 27, 81))), ) old_code = (old_packed // divisor) % 3 old_sign = old_code.to(tl.int32) - 1 old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32) new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP)) if HAS_PER_GROUP_THRESHOLD: g_idx = n * GPR + k // GROUP_SIZE threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32) else: threshold = ACCUM_THRESHOLD flip_up = new_accum > threshold flip_down = new_accum < -threshold did_flip = valid & (flip_up | flip_down) new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign)) stored_accum = tl.where(did_flip, 0, new_accum) tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid) new_code = tl.where(valid, new_sign + 1, 0) packed_val = tl.sum(new_code * divisor, axis=0) tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8)) @triton.jit def _triton_accumulate_t_direct_kernel( grad_ptr, x_ptr, accum_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TOTAL: tl.constexpr, T_ACCUM_STEP: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr, ): pack_idx = tl.program_id(0) offs_t = tl.arange(0, BLOCK_T) lin = pack_idx * 5 + offs_t valid_trit = offs_t < 5 valid = valid_trit & (lin < TOTAL) n = lin // K k = lin - n * K offs_m = tl.arange(0, BLOCK_M) acc = tl.zeros((BLOCK_T,), dtype=tl.float32) for m0 in range(0, M, BLOCK_M): m = m0 + offs_m grad = tl.load( grad_ptr + m[:, None] * N + n[None, :], mask=(m[:, None] < M) & valid[None, :], other=0.0, ) x = tl.load( x_ptr + m[:, None] * K + k[None, :], mask=(m[:, None] < M) & valid[None, :], other=0.0, ) acc += tl.sum(grad * x, axis=0) grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32) old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32) new_accum = tl.minimum(127, tl.maximum(-128, old_accum - grad_sign * T_ACCUM_STEP)) tl.store(accum_ptr + lin, new_accum.to(tl.int8), mask=valid) @triton.jit def _triton_accumulate_e_direct_kernel( packed_ptr, grad_ptr, x_ptr, e_accum_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, GROUP_SIZE: tl.constexpr, GPR: tl.constexpr, E_ACCUM_STEP: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_n = tl.program_id(0) pid_g = tl.program_id(1) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_r = tl.arange(0, BLOCK_K) k = pid_g * GROUP_SIZE + offs_r offs_m = tl.arange(0, BLOCK_M) acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32) for m0 in range(0, M, BLOCK_M): m = m0 + offs_m grad = tl.load( grad_ptr + m[:, None] * N + offs_n[None, :], mask=(m[:, None] < M) & (offs_n[None, :] < N), other=0.0, ) x = tl.load( x_ptr + m[:, None] * K + k[None, :], mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), other=0.0, ) acc += tl.dot(tl.trans(grad), x, input_precision="ieee") grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32) lin = offs_n[:, None] * K + k[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load( packed_ptr + pack_idx, mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), other=0, ).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 ternary = trit.to(tl.int32) - 1 contrib = tl.where( (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), grad_sign * ternary, 0, ) score = tl.sum(contrib, axis=1) delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0)) e_idx = offs_n * GPR + pid_g old_accum = tl.load(e_accum_ptr + e_idx, mask=offs_n < N, other=0).to(tl.int32) new_accum = tl.minimum(127, tl.maximum(-128, old_accum + delta * E_ACCUM_STEP)) tl.store(e_accum_ptr + e_idx, new_accum.to(tl.int8), mask=offs_n < N) @triton.jit def _triton_accumulate_corr_direct_kernel( packed_ptr, grad_ptr, x_ptr, corr_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, GROUP_SIZE: tl.constexpr, GPR: tl.constexpr, CORR_STEP: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_n = tl.program_id(0) pid_g = tl.program_id(1) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_r = tl.arange(0, BLOCK_K) k = pid_g * GROUP_SIZE + offs_r offs_m = tl.arange(0, BLOCK_M) acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32) for m0 in range(0, M, BLOCK_M): m = m0 + offs_m grad = tl.load( grad_ptr + m[:, None] * N + offs_n[None, :], mask=(m[:, None] < M) & (offs_n[None, :] < N), other=0.0, ) x = tl.load( x_ptr + m[:, None] * K + k[None, :], mask=(m[:, None] < M) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), other=0.0, ) acc += tl.dot(tl.trans(grad), x, input_precision="ieee") grad_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32) lin = offs_n[:, None] * K + k[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load( packed_ptr + pack_idx, mask=(offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), other=0, ).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 ternary = trit.to(tl.int32) - 1 contrib = tl.where( (offs_n[:, None] < N) & (offs_r[None, :] < GROUP_SIZE) & (k[None, :] < K), grad_sign * ternary, 0, ) score = tl.sum(contrib, axis=1) corr_idx = offs_n * GPR + pid_g old_corr = tl.load(corr_ptr + corr_idx, mask=offs_n < N, other=0).to(tl.int64) new_corr = old_corr - score.to(tl.int64) * CORR_STEP tl.store(corr_ptr + corr_idx, new_corr, mask=offs_n < N) @triton.jit def _triton_apply_accumulated_flips_kernel( packed_ptr, accum_ptr, per_group_threshold_ptr, TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr, K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, HAS_PER_GROUP_THRESHOLD: tl.constexpr, BLOCK_T: tl.constexpr, ): pack_idx = tl.program_id(0) offs_t = tl.arange(0, BLOCK_T) valid_trit = offs_t < 5 lin = pack_idx * 5 + offs_t valid = valid_trit & (lin < TOTAL) old_packed = tl.load(packed_ptr + pack_idx).to(tl.int32) divisor = tl.where( offs_t == 0, 1, tl.where(offs_t == 1, 3, tl.where(offs_t == 2, 9, tl.where(offs_t == 3, 27, 81))), ) old_code = (old_packed // divisor) % 3 old_sign = old_code.to(tl.int32) - 1 old_accum = tl.load(accum_ptr + lin, mask=valid, other=0).to(tl.int32) if HAS_PER_GROUP_THRESHOLD: n = lin // K k = lin - n * K g_idx = n * GPR + k // GROUP_SIZE threshold = tl.load(per_group_threshold_ptr + g_idx, mask=valid, other=ACCUM_THRESHOLD).to(tl.int32) else: threshold = ACCUM_THRESHOLD flip_up = old_accum > threshold flip_down = old_accum < -threshold did_flip = valid & (flip_up | flip_down) new_sign = tl.where(flip_up, 1, tl.where(flip_down, -1, old_sign)) stored_accum = tl.where(did_flip, 0, old_accum) tl.store(accum_ptr + lin, stored_accum.to(tl.int8), mask=valid) new_code = tl.where(valid, new_sign + 1, 0) packed_val = tl.sum(new_code * divisor, axis=0) tl.store(packed_ptr + pack_idx, packed_val.to(tl.uint8)) def _triton_ternary_forward(x_2d, packed, e, corr_accum, step_counter, n_out, k_in, group_size): block_m, block_n, block_k = 16, 16, 32 out = torch.empty((x_2d.shape[0], n_out), device=x_2d.device, dtype=torch.float32) grid = (triton.cdiv(x_2d.shape[0], block_m), triton.cdiv(n_out, block_n)) _triton_ternary_fwd_kernel[grid]( x_2d, packed, e, corr_accum, step_counter, out, x_2d.shape[0], n_out, k_in, ceil(k_in / group_size), group_size, _bigint_corr_strength(), BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, ) return out def _triton_ternary_grad_x(grad_2d, packed, e, corr_accum, step_counter, m_rows, n_out, k_in, group_size): block_m, block_n, block_k = 16, 16, 32 out = torch.empty((m_rows, k_in), device=grad_2d.device, dtype=torch.float32) grid = (triton.cdiv(m_rows, block_m), triton.cdiv(k_in, block_k)) _triton_ternary_grad_x_kernel[grid]( grad_2d, packed, e, corr_accum, step_counter, out, m_rows, n_out, k_in, ceil(k_in / group_size), group_size, _bigint_corr_strength(), BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, ) return out def _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in): block_m, block_n, block_k = 32, 16, 32 out = torch.empty((n_out, k_in), device=grad_2d.device, dtype=torch.int8) grid = (triton.cdiv(n_out, block_n), triton.cdiv(k_in, block_k)) _triton_ternary_grad_sign_kernel[grid]( grad_2d, x_2d, out, x_2d.shape[0], n_out, k_in, BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, ) return out def _triton_update_e(packed, grad_sign, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4): block_n, block_g = 8, 4 gpr = ceil(k_in / group_size) block_k = 1 << (group_size - 1).bit_length() grid = (triton.cdiv(n_out, block_n), triton.cdiv(gpr, block_g)) _triton_update_e_kernel[grid]( packed, grad_sign, e, e_accum, n_out, k_in, group_size, gpr, int(e_accum_threshold), BLOCK_N=block_n, BLOCK_G=block_g, BLOCK_K=block_k, ) def _triton_update_e_direct(packed, grad_2d, x_2d, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4): block_m, block_n = 32, 8 block_k = 1 << (group_size - 1).bit_length() gpr = ceil(k_in / group_size) grid = (triton.cdiv(n_out, block_n), gpr) _triton_update_e_direct_kernel[grid]( packed, grad_2d, x_2d, e, e_accum, x_2d.shape[0], n_out, k_in, group_size, gpr, int(e_accum_threshold), BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, ) def _triton_ternary_step(packed, grad_sign, accum, total, accum_threshold, t_accum_step=1, per_group_threshold=None, n_out=0, k_in=0, group_size=0): block_t = 8 grid = (triton.cdiv(total, 5),) has_pgt = per_group_threshold is not None dummy = torch.empty(1, device=accum.device, dtype=torch.int8) gpr = (k_in + group_size - 1) // group_size if has_pgt else 0 _triton_ternary_step_kernel[grid]( packed, grad_sign, accum, per_group_threshold if has_pgt else dummy, total, accum_threshold, int(t_accum_step), k_in if has_pgt else 0, gpr, group_size if has_pgt else 0, has_pgt, BLOCK_T=block_t, ) def _triton_ternary_step_direct(packed, grad_2d, x_2d, accum, n_out, k_in, total, accum_threshold, t_accum_step=1, per_group_threshold=None, group_size=0): block_m, block_t = 32, 8 grid = (triton.cdiv(total, 5),) has_pgt = per_group_threshold is not None dummy = torch.empty(1, device=accum.device, dtype=torch.int8) gpr = (k_in + group_size - 1) // group_size if has_pgt else 0 _triton_ternary_step_direct_kernel[grid]( packed, grad_2d, x_2d, accum, per_group_threshold if has_pgt else dummy, x_2d.shape[0], n_out, k_in, total, accum_threshold, int(t_accum_step), gpr, group_size if has_pgt else 0, has_pgt, BLOCK_M=block_m, BLOCK_T=block_t, ) def _triton_accumulate_direct(packed, grad_2d, x_2d, t_accum, e_accum, n_out, k_in, group_size, t_accum_step=1, e_accum_step=1, update_scales=True): block_m, block_t = 32, 8 total = n_out * k_in grid = (triton.cdiv(total, 5),) _triton_accumulate_t_direct_kernel[grid]( grad_2d, x_2d, t_accum, grad_2d.shape[0], n_out, k_in, total, int(t_accum_step), BLOCK_M=block_m, BLOCK_T=block_t, ) if update_scales and e_accum is not None: block_n = 8 block_k = 1 << (group_size - 1).bit_length() gpr = ceil(k_in / group_size) grid_e = (triton.cdiv(n_out, block_n), gpr) _triton_accumulate_e_direct_kernel[grid_e]( packed, grad_2d, x_2d, e_accum, grad_2d.shape[0], n_out, k_in, group_size, gpr, int(e_accum_step), BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, ) def _triton_accumulate_corr_direct(packed, grad_2d, x_2d, corr_accum, n_out, k_in, group_size, corr_step=1): block_m, block_n = 32, 8 block_k = 1 << (group_size - 1).bit_length() gpr = ceil(k_in / group_size) grid = (triton.cdiv(n_out, block_n), gpr) _triton_accumulate_corr_direct_kernel[grid]( packed, grad_2d, x_2d, corr_accum, grad_2d.shape[0], n_out, k_in, group_size, gpr, int(corr_step), BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k, ) def _triton_apply_accumulated_flips(packed, accum, total, accum_threshold, per_group_threshold=None, k_in=0, group_size=0): block_t = 8 grid = (triton.cdiv(total, 5),) has_pgt = per_group_threshold is not None dummy = torch.empty(1, device=accum.device, dtype=torch.int8) gpr = (k_in + group_size - 1) // group_size if has_pgt else 0 _triton_apply_accumulated_flips_kernel[grid]( packed, accum, per_group_threshold if has_pgt else dummy, total, accum_threshold, k_in if has_pgt else 0, gpr, group_size if has_pgt else 0, has_pgt, BLOCK_T=block_t, ) @triton.jit def _triton_ternary_embed_fwd_kernel( idx_ptr, packed_ptr, e_ptr, out_ptr, NUM_IDX: tl.constexpr, DIM: tl.constexpr, VOCAB: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr, ): pid = tl.program_id(0) offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B) offs_d = tl.arange(0, BLOCK_D) idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32) lin = idx[:, None] * DIM + offs_d[None, :] pack_idx = lin // 5 trit_pos = lin - pack_idx * 5 packed = tl.load(packed_ptr + pack_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 sign = trit.to(tl.int32) - 1 e_idx = idx[:, None] * GPR + offs_d[None, :] // GROUP_SIZE e_val = tl.load(e_ptr + e_idx, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), other=0).to(tl.float32) w = sign.to(tl.float32) * tl.exp2(e_val) w = tl.where((offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), w, 0.0) tl.store( out_ptr + offs_b[:, None] * DIM + offs_d[None, :], w, mask=(offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM), ) @triton.jit def _triton_ternary_embed_bwd_accum_kernel( idx_ptr, grad_ptr, accum_ptr, NUM_IDX: tl.constexpr, DIM: tl.constexpr, BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr, ): pid = tl.program_id(0) offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B) offs_d = tl.arange(0, BLOCK_D) valid = (offs_b[:, None] < NUM_IDX) & (offs_d[None, :] < DIM) idx = tl.load(idx_ptr + offs_b, mask=offs_b < NUM_IDX, other=0).to(tl.int32) g = tl.load(grad_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0) dst = idx[:, None] * DIM + offs_d[None, :] tl.atomic_add(accum_ptr + dst, g, mask=valid) @triton.jit def _triton_ternary_embed_bwd_sign_kernel( accum_ptr, sign_ptr, VOCAB: tl.constexpr, DIM: tl.constexpr, BLOCK_V: tl.constexpr, BLOCK_D: tl.constexpr, ): pid_v = tl.program_id(0) offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V) offs_d = tl.arange(0, BLOCK_D) valid = (offs_v[:, None] < VOCAB) & (offs_d[None, :] < DIM) acc = tl.load(accum_ptr + offs_v[:, None] * DIM + offs_d[None, :], mask=valid, other=0.0) sign_val = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int8) tl.store(sign_ptr + offs_v[:, None] * DIM + offs_d[None, :], sign_val, mask=valid) def _triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim): flat_idx = indices.reshape(-1).contiguous().to(torch.int32) grad_2d = grad_output.reshape(-1, dim).contiguous() num_idx = flat_idx.shape[0] accum = torch.zeros(vocab, dim, device=grad_output.device, dtype=torch.float32) block_b = 64 grid = (triton.cdiv(num_idx, block_b),) _triton_ternary_embed_bwd_accum_kernel[grid]( flat_idx, grad_2d, accum, num_idx, dim, BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim), ) sign_out = torch.empty(vocab, dim, device=grad_output.device, dtype=torch.int8) block_v = 32 grid2 = (triton.cdiv(vocab, block_v),) _triton_ternary_embed_bwd_sign_kernel[grid2]( accum, sign_out, vocab, dim, BLOCK_V=block_v, BLOCK_D=triton.next_power_of_2(dim), ) return sign_out def _triton_ternary_embed(indices, packed, e, vocab, dim, group_size): flat_idx = indices.reshape(-1).contiguous().to(torch.int32) num_idx = flat_idx.shape[0] out = torch.empty((num_idx, dim), device=indices.device, dtype=torch.float32) block_b, block_d = 32, triton.next_power_of_2(dim) gpr = ceil(dim / group_size) grid = (triton.cdiv(num_idx, block_b),) _triton_ternary_embed_fwd_kernel[grid]( flat_idx, packed, e, out, num_idx, dim, vocab, gpr, group_size, BLOCK_B=block_b, BLOCK_D=block_d, ) return out.reshape(*indices.shape, dim) class _TritonTernaryEmbedFn(torch.autograd.Function): @staticmethod def forward(ctx, indices, _dummy, module): shape = tuple(module._T_shape.tolist()) vocab, dim = shape packed = module.T_packed.contiguous() e = module.E.contiguous() ctx.save_for_backward(indices, packed, e) ctx.module = module ctx.shape = shape ctx.group_size = module.group_size comp_name, _ = _COMPONENT_CONTEXT.get() ctx.comp_name = comp_name return _triton_ternary_embed(indices, packed, e, vocab, dim, module.group_size) @staticmethod def backward(ctx, grad_output): indices, packed, e = ctx.saved_tensors vocab, dim = ctx.shape grad_2d = grad_output.reshape(-1, dim).contiguous() comp_name = ctx.comp_name has_corr = hasattr(ctx.module, "corr_accum") and hasattr(ctx.module, "_accumulate_corr_from_grad_sign") if getattr(ctx.module, "_stream_backward_updates", True) and has_corr: # BigInt streaming: accumulate correlation directly grad_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim) T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item())).to(device=grad_sign.device) signed = grad_sign.to(torch.int16) * T.to(torch.int16) ctx.module._accumulate_corr_from_grad_sign(grad_sign) ctx.module._streamed_bigint_backward = True elif comp_name is not None: setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)) T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item())) setattr(ctx.module, f"_hook_T_{comp_name}", T.to(device=grad_2d.device)) else: ctx.module._hook_grad_T_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim) T = unpack_ternary(packed, tuple(ctx.module._T_shape.tolist()), int(ctx.module._T_pad.item())) ctx.module._hook_T = T.to(device=grad_2d.device) return None, None, None class _TritonTernaryLinearFn(torch.autograd.Function): @staticmethod def forward(ctx, x, module): shape = tuple(module._T_shape.tolist()) n_out, k_in = shape x_2d = x.reshape(-1, k_in).contiguous() packed = module.T_packed.contiguous() e = module.E.contiguous() ctx.save_for_backward(x_2d, packed, e) ctx.step_snapshot = int(module.step_counter.item()) ctx.x_shape = x.shape ctx.shape = shape ctx.group_size = module.group_size ctx.module = module comp_name, _ = _COMPONENT_CONTEXT.get() ctx.comp_name = comp_name corr = module.corr_accum.contiguous() step = module.step_counter.contiguous() out = _triton_ternary_forward(x_2d, packed, e, corr, step, n_out, k_in, module.group_size) return out.reshape(*x.shape[:-1], n_out) @staticmethod def backward(ctx, grad_output): x_2d, packed, e = ctx.saved_tensors n_out, k_in = ctx.shape grad_2d = grad_output.reshape(-1, n_out).contiguous() corr = ctx.module.corr_accum.contiguous() step = torch.tensor([ctx.step_snapshot], device=e.device, dtype=torch.int64) grad_x = _triton_ternary_grad_x( grad_2d, packed, e, corr, step, x_2d.shape[0], n_out, k_in, ctx.group_size ) with torch.no_grad(): if getattr(ctx.module, "_stream_backward_updates", True): _, bwd_weight = _COMPONENT_CONTEXT.get() corr_step = max(1, int(round(abs(float(bwd_weight))))) if bwd_weight < 0: corr_step = -corr_step _triton_accumulate_corr_direct( packed, grad_2d, x_2d, ctx.module.corr_accum, n_out, k_in, ctx.group_size, corr_step=corr_step, ) ctx.module.step_counter.add_(abs(corr_step)) ctx.module._streamed_bigint_backward = True else: grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in) comp_name = ctx.comp_name if comp_name is not None: setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach()) else: ctx.module._hook_grad_T_sign = grad_sign.detach() return grad_x.reshape(*ctx.x_shape), None class _BigIntTernaryLinearFn(torch.autograd.Function): @staticmethod def forward(ctx, x, module): shape = tuple(module._T_shape.tolist()) n_out, k_in = shape x_2d = x.reshape(-1, k_in).contiguous() ctx.module = module ctx.x_shape = x.shape ctx.shape = shape ctx.x_dtype = x.dtype ctx.save_for_backward(x_2d) with torch.no_grad(): w_eff = module.dequantize().to(device=x.device, dtype=torch.float32) out = F.linear(x_2d.float(), w_eff, module.bias.float() if module.bias is not None else None) return out.reshape(*x.shape[:-1], n_out) @staticmethod def backward(ctx, grad_output): (x_2d,) = ctx.saved_tensors module = ctx.module n_out, k_in = ctx.shape grad_2d = grad_output.reshape(-1, n_out).contiguous() with torch.no_grad(): w_eff = module.dequantize().to(device=grad_2d.device, dtype=torch.float32) grad_x = grad_2d.float() @ w_eff grad_sign = (grad_2d.float().transpose(0, 1) @ x_2d.float()).sign().to(torch.int8) module._accumulate_corr_from_grad_sign(grad_sign) module._streamed_bigint_backward = True return grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype), None """ Log-Space Group Scale Representation Convention (matching agents' Option B recommendation): S = 2^E where S = scale, E = int8 log-space exponent W_eff = T * 2^E Key log-space properties exploited: Multiplication → addition: S1 * S2 = 2^(E1 + E2) Division → subtraction: S1 / S2 = 2^(E1 - E2) Dequant → integer shift: 2^E * T = T << E (for E >= 0) No IEEE floats in persistent state. E is stored as int8. Ephemeral float only exists in autograd's computation graph. """ class TScaleType(IntEnum): T4 = 4 T6 = 6 T8 = 8 T16 = 16 T32 = 32 T64 = 64 T96 = 96 GROUP_SIZES = { TScaleType.T4: 4, TScaleType.T6: 6, TScaleType.T8: 8, TScaleType.T16: 16, TScaleType.T32: 32, TScaleType.T64: 64, TScaleType.T96: 96, } TILE_SIZE = 384 def _n_groups(shape, group_size): out_dim, in_dim = shape return out_dim * ceil(in_dim / group_size) def _expand_E(E, shape, group_size): out_dim, in_dim = shape gpr = ceil(in_dim / group_size) E_2d = E.view(out_dim, gpr) E_exp = E_2d.repeat_interleave(group_size, dim=1) if E_exp.shape[1] > in_dim: E_exp = E_exp[:, :in_dim] return E_exp def _ternarize(x, threshold=0.05): return x.sign() * (x.abs() > threshold).to(x.dtype) def _scaled_init_threshold(threshold: float, init_std: float) -> float: if init_std <= 0: return threshold return min(float(threshold), 0.5 * float(init_std)) class TernaryScaleTensor(nn.Module): def __init__( self, in_dim: int, out_dim: int, threshold: float = 0.05, weight_init_std: float | None = None, tscale_type: TScaleType = TScaleType.T32, bias: bool = False, ): super().__init__() self.in_dim = in_dim self.out_dim = out_dim init_std = min(0.1, in_dim ** -0.5) if weight_init_std is None else float(weight_init_std) init_threshold = _scaled_init_threshold(threshold, init_std) self.threshold = init_threshold self.tscale_type = tscale_type self.group_size = GROUP_SIZES[tscale_type] shape = (out_dim, in_dim) n_grp = _n_groups(shape, self.group_size) w_init = torch.randn(out_dim, in_dim) * init_std T_init = _ternarize(w_init, init_threshold) packed_T, T_shape, T_pad = pack_ternary(T_init) self.register_buffer("T_packed", packed_T) self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long)) self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long)) gpr = ceil(in_dim / self.group_size) total_in = gpr * self.group_size padded = torch.zeros(out_dim, total_in) abs_w = w_init.abs() padded[:, :in_dim] = abs_w grouped = padded.view(out_dim, gpr, self.group_size) grp_means = grouped.mean(dim=2) E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means)) E_int = E_vals.log2().clamp(-128, 127).to(torch.int8) self.register_buffer("E", E_int.flatten()) self.register_buffer("corr_accum", torch.zeros_like(self.E, dtype=torch.int64)) self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64)) if bias: self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32)) else: self.bias = None def _get_T(self): return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item())) def _get_S(self): gpr = ceil(self.in_dim / self.group_size) e_adj = self.E.float() if hasattr(self, "corr_accum") and hasattr(self, "step_counter"): step = int(self.step_counter.item()) if step > 0: denom = max(step * self.group_size, 1) e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength() E_exp = _expand_E(e_adj, (self.out_dim, self.in_dim), self.group_size) return torch.exp2(E_exp) def _ensure_group_lr(self): if not hasattr(self, "group_lr"): self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8)) elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device: self.group_lr = torch.ones_like(self.E, dtype=torch.int8) return self.group_lr def precompile_kernels(self, M: int): pass def forward(self, x): backend = _backend_preference() if backend == "tilelang" and _HAS_TILELANG: if torch.is_grad_enabled() and not _tilelang_training_enabled(): raise RuntimeError( "ARB_TERNARY_BACKEND='tilelang' is inference-only by default. " "BigInt ternary training should use ARB_TERNARY_BACKEND='triton'. " "Set ARB_TILELANG_TRAINING=1 only for experimental TileLang training." ) x_for_grad = x if torch.is_grad_enabled() and not x.requires_grad: x_for_grad = x.detach().requires_grad_(True) N, K = tuple(self._T_shape.tolist()) x_2d = x_for_grad.reshape(-1, K) M = x_2d.shape[0] try: fwd_kernel = _get_kernel(M, N, K, self.group_size, "fwd") y = _TernaryLinearFn.apply(x_for_grad, self, fwd_kernel) if self.bias is not None: y = y + self.bias.float() return y except Exception as e: warnings.warn(f"TileLang forward failed for {self._T_shape.tolist()}: {e}") if _HAS_TRITON: backend = "triton" else: backend = "torch" if x.is_cuda and _HAS_TRITON and backend in {"auto", "triton"}: x_for_grad = x if torch.is_grad_enabled() and not x.requires_grad: x_for_grad = x.detach().requires_grad_(True) y = _TritonTernaryLinearFn.apply(x_for_grad, self) if self.bias is not None: y = y + self.bias.float() return y if backend == "triton": raise RuntimeError("ARB_TERNARY_BACKEND='triton' requested, but Triton is unavailable for this input.") x_for_grad = x if torch.is_grad_enabled() and not x.requires_grad: x_for_grad = x.detach().requires_grad_(True) return _BigIntTernaryLinearFn.apply(x_for_grad, self) @torch.no_grad() def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1): shape = tuple(self._T_shape.tolist()) out_dim, in_dim = shape if tuple(grad_sign.shape) != shape: return T = self._get_T().to(device=grad_sign.device, dtype=torch.int16) signed = grad_sign.to(torch.int16) * T gpr = ceil(in_dim / self.group_size) total_in = gpr * self.group_size if total_in > in_dim: signed = F.pad(signed, (0, total_in - in_dim)) score = signed.view(out_dim, gpr, self.group_size).sum(dim=2, dtype=torch.int16) self.corr_accum -= score.flatten().to(device=self.corr_accum.device, dtype=torch.int64) * int(corr_step) self.step_counter += abs(int(corr_step)) def ternary_step(self, lr=1, accum_threshold=None): self._had_flip = False if hasattr(self, "_hook_grad_T_sign"): self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign) del self._hook_grad_T_sign def update_E(self, lr=1, loss_signal=None): has_dense_grad = hasattr(self, "_hook_grad_T_sign") has_direct_grad = hasattr(self, "_hook_grad_2d") and hasattr(self, "_hook_x_2d") if not has_dense_grad and not has_direct_grad: return if has_dense_grad: self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign) del self._hook_grad_T_sign else: grad = self._hook_grad_2d.to(device=self.E.device, dtype=torch.float32) x = self._hook_x_2d.to(device=self.E.device, dtype=torch.float32) grad_sign = (grad.transpose(0, 1) @ x).sign().to(torch.int8) self._accumulate_corr_from_grad_sign(grad_sign) del self._hook_grad_2d del self._hook_x_2d if hasattr(self, "_hook_T"): del self._hook_T @property def effective_bpw(self) -> float: group_size = self.group_size total = self._T_shape[0].item() * self._T_shape[1].item() n_grp = _n_groups(tuple(self._T_shape.tolist()), group_size) sign_bits = total * (8 / 5) scale_bits = n_grp * 8.0 corr_bits = n_grp * 64.0 bias_bits = self.bias.numel() * 32.0 if self.bias is not None else 0.0 return (sign_bits + scale_bits + corr_bits + bias_bits) / total def dequantize(self) -> torch.Tensor: T = self._get_T().float() S = self._get_S() return S * T def tscale_to(self, tscale_type: TScaleType): self.tscale_type = tscale_type old_group_size = self.group_size self.group_size = GROUP_SIZES[tscale_type] shape = tuple(self._T_shape.tolist()) out_dim, in_dim = shape new_gpr = ceil(in_dim / self.group_size) new_n_grp = out_dim * new_gpr if self.E.shape[0] != new_n_grp: T = self._get_T().float() total_in = new_gpr * self.group_size padded = torch.zeros(out_dim, total_in, device=self.T_packed.device) abs_w = T.abs() padded[:, :in_dim] = abs_w grouped = padded.view(out_dim, new_gpr, self.group_size) grp_means = grouped.mean(dim=2) E_new = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means)) E_int = E_new.log2().clamp(-128, 127).to(torch.int8) self.E = E_int.flatten() self.corr_accum = torch.zeros_like(self.E, dtype=torch.int64) self.step_counter = torch.zeros(1, dtype=torch.int64, device=self.E.device) return self tscale_cast = tscale_to def extra_repr(self) -> str: return ( f"in_dim={self.in_dim}, out_dim={self.out_dim}, " f"tscale_type={self.tscale_type.name}, group_size={self.group_size}, " f"effective_bpw={self.effective_bpw:.2f}" ) if _HAS_TRITON: @triton.jit def _triton_rmsnorm_fwd_kernel( x_ptr, packed_ptr, e_ptr, out_ptr, BATCH: tl.constexpr, DIM: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr, ): pid_b = tl.program_id(0) offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) offs_d = tl.arange(0, BLOCK_D) x = tl.load( x_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM), other=0.0, ) sq = x * x msq = tl.sum(sq, axis=1, keep_dims=True) / DIM rms = tl.sqrt(msq + 1e-5) x_norm = x / rms pack_idx = offs_d // 5 trit_pos = offs_d - pack_idx * 5 packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 sign = trit.to(tl.int32) - 1 e_idx = offs_d // GROUP_SIZE e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32) w = sign.to(tl.float32) * tl.exp2(e_val) w = tl.where(offs_d < DIM, w, 0.0) out = x_norm * w[None, :] tl.store( out_ptr + offs_b[:, None] * DIM + offs_d[None, :], out, mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM), ) @triton.jit def _triton_rmsnorm_bwd_kernel( grad_out_ptr, x_ptr, packed_ptr, e_ptr, grad_x_ptr, BATCH: tl.constexpr, DIM: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr, BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr, ): pid_b = tl.program_id(0) offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) offs_d = tl.arange(0, BLOCK_D) x = tl.load( x_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM), other=0.0, ) sq = x * x msq = tl.sum(sq, axis=1, keep_dims=True) / DIM rms = tl.sqrt(msq + 1e-5) x_norm = x / rms pack_idx = offs_d // 5 trit_pos = offs_d - pack_idx * 5 packed = tl.load(packed_ptr + pack_idx, mask=offs_d < DIM, other=0).to(tl.int32) divisor = tl.where( trit_pos == 0, 1, tl.where(trit_pos == 1, 3, tl.where(trit_pos == 2, 9, tl.where(trit_pos == 3, 27, 81))), ) trit = (packed // divisor) % 3 sign = trit.to(tl.int32) - 1 e_idx = offs_d // GROUP_SIZE e_val = tl.load(e_ptr + e_idx, mask=offs_d < DIM, other=0).to(tl.float32) w = sign.to(tl.float32) * tl.exp2(e_val) w = tl.where(offs_d < DIM, w, 0.0) dy = tl.load( grad_out_ptr + offs_b[:, None] * DIM + offs_d[None, :], mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM), other=0.0, ) dyw = dy * w[None, :] c1 = tl.sum(x_norm * dyw, axis=1, keep_dims=True) / DIM dx = (dyw - x_norm * c1) / rms tl.store( grad_x_ptr + offs_b[:, None] * DIM + offs_d[None, :], dx, mask=(offs_b[:, None] < BATCH) & (offs_d[None, :] < DIM), ) class _TritonRMSNormFn(torch.autograd.Function): @staticmethod def forward(ctx, x, module, packed, e, dim, group_size): ctx.module = module x_2d = x.reshape(-1, dim).contiguous() batch = x_2d.shape[0] out = torch.empty_like(x_2d) block_b = 16 grid = (triton.cdiv(batch, block_b),) _triton_rmsnorm_fwd_kernel[grid]( x_2d, packed, e, out, batch, dim, ceil(dim / group_size), group_size, BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim), ) ctx.save_for_backward(x_2d, packed, e) ctx.dim = dim ctx.group_size = group_size comp_name, _ = _COMPONENT_CONTEXT.get() ctx.comp_name = comp_name return out.reshape(*x.shape) @staticmethod def backward(ctx, grad_output): x_2d, packed, e = ctx.saved_tensors dim = ctx.dim group_size = ctx.group_size grad_2d = grad_output.reshape(-1, dim).contiguous() batch = grad_2d.shape[0] grad_x = torch.empty_like(x_2d) block_b = 16 grid = (triton.cdiv(batch, block_b),) _triton_rmsnorm_bwd_kernel[grid]( grad_2d, x_2d, packed, e, grad_x, batch, dim, ceil(dim / group_size), group_size, BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim), ) return grad_x.reshape(*grad_output.shape), None, None, None, None, None class TernaryRMSNorm(nn.Module): def __init__(self, dim, eps=1e-5, threshold=0.05, tscale_type=TScaleType.T64): super().__init__() self.dim = dim self.eps = eps self.threshold = threshold self.tscale_type = tscale_type self.group_size = GROUP_SIZES[tscale_type] shape = (1, dim) n_grp = _n_groups(shape, self.group_size) w_init = torch.ones(1, dim) T_init = _ternarize(w_init, threshold) packed_T, T_shape, T_pad = pack_ternary(T_init) self.register_buffer("T_packed", packed_T) self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long)) self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long)) gpr = ceil(dim / self.group_size) total_in = gpr * self.group_size padded = torch.zeros(1, total_in) abs_w = w_init.abs() padded[:, :dim] = abs_w grouped = padded.view(1, gpr, self.group_size) grp_means = grouped.mean(dim=2) E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means)) self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8)) self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8)) self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8)) self.register_buffer("T_accum", torch.zeros(1, dim, dtype=torch.int8)) def _ensure_E_accum(self): if not hasattr(self, "E_accum"): self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8)) elif self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device: self.E_accum = torch.zeros_like(self.E, dtype=torch.int8) return self.E_accum def _ensure_group_lr(self): if not hasattr(self, "group_lr"): self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8)) elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device: self.group_lr = torch.ones_like(self.E, dtype=torch.int8) return self.group_lr def _get_T(self): return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item())).squeeze(0) def forward(self, x): if x.is_cuda and _HAS_TRITON and self.dim <= _rmsnorm_triton_max_dim(): return _TritonRMSNormFn.apply( x, self, self.T_packed.contiguous(), self.E.contiguous(), self.dim, self.group_size, ) inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) if x.is_cuda: # TernaryRMSNorm is initialized as an identity scale and does not # train E/T. Avoid unpacking a full large-dim weight or launching # the high-register Triton backward kernel on 8GB GPUs. return x * inv_rms T = self._get_T() E_exp = _expand_E(self.E, tuple(self._T_shape.tolist()), self.group_size).squeeze(0) S = torch.exp2(E_exp.float()) weight = S * T.float() return weight * (x * inv_rms) def ternary_step(self, lr=1, accum_threshold=3): pass def update_E(self, lr=1, loss_signal=None): pass def extra_repr(self): return f"dim={self.dim}, tscale_type={self.tscale_type.name}"