# Source: ARBS / arbitor / kernel / ternary_scale.py
# Hugging Face Hub page header: uploader CLIWorks, "Upload folder using huggingface_hub",
# commit d8bc908 (verified).
import os
import threading
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import IntEnum
from math import ceil
from ..converters.convert_to_ternary8 import pack_ternary, unpack_ternary
_HAS_TILELANG = False
try:
import tilelang
import tilelang.language as T
_HAS_TILELANG = True
except ImportError:
pass
_HAS_TRITON = False
try:
import triton
import triton.language as tl
_HAS_TRITON = True
except ImportError:
pass
def _backend_preference() -> str:
    """Return the ternary backend requested via ``ARB_TERNARY_BACKEND``.

    The variable is read on every call, then stripped and lower-cased.
    Accepted values are ``auto``, ``tilelang``, ``triton`` and ``torch``;
    an unset variable means ``auto``.  Any other value emits a
    ``RuntimeWarning`` and is treated as ``auto``.
    """
    recognised = ("auto", "tilelang", "triton", "torch")
    backend = os.environ.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if backend in recognised:
        return backend
    warnings.warn(
        f"Unknown ARB_TERNARY_BACKEND={backend!r}; falling back to auto.",
        RuntimeWarning,
        stacklevel=2,
    )
    return "auto"
def _rmsnorm_triton_max_dim() -> int:
    """Return the integer limit configured in ``ARB_RMSNORM_TRITON_MAX_DIM``.

    Defaults to 4096 when the variable is unset.  Negative values are
    clamped to 0.  A value that ``int()`` cannot parse emits a
    ``RuntimeWarning`` and yields 4096.
    """
    raw = os.environ.get("ARB_RMSNORM_TRITON_MAX_DIM", "4096").strip()
    try:
        parsed = int(raw)
    except ValueError:
        warnings.warn(
            f"Invalid ARB_RMSNORM_TRITON_MAX_DIM={raw!r}; using 4096.",
            RuntimeWarning,
            stacklevel=2,
        )
        return 4096
    # Same result as max(0, parsed): never report a negative limit.
    return parsed if parsed > 0 else 0
def _bigint_corr_strength() -> float:
    """Return the correction strength configured in ``ARB_BIGINT_CORR_STRENGTH``.

    Defaults to 4.0 when the variable is unset.  The value must parse as a
    *finite* float: ``float()`` also accepts spellings such as ``"nan"`` and
    ``"inf"``, and a non-finite strength would propagate into every exponent
    computed by the kernels that consume this value.  Unparsable or
    non-finite values emit a ``RuntimeWarning`` and yield 4.0.
    """
    # Local import: only ``ceil`` is imported from ``math`` at file level.
    from math import isfinite

    raw = os.environ.get("ARB_BIGINT_CORR_STRENGTH", "4.0").strip()
    try:
        value = float(raw)
    except ValueError:
        value = None
    if value is not None and isfinite(value):
        return value
    warnings.warn(
        f"Invalid ARB_BIGINT_CORR_STRENGTH={raw!r}; using 4.0.",
        RuntimeWarning,
        stacklevel=2,
    )
    return 4.0
class _ComponentContext:
    """Per-thread record of the currently active ``(name, weight)`` pair.

    The state lives in a class-level ``threading.local``, so each thread sees
    only the value it set itself.  All access goes through classmethods; the
    class is never instantiated in this file.
    """

    _local = threading.local()

    @classmethod
    def get(cls):
        """Return the stored ``(name, weight)`` tuple, or ``(None, 1.0)`` if unset."""
        current = getattr(cls._local, "current", None)
        return (None, 1.0) if current is None else current

    @classmethod
    def set(cls, name, weight=1.0):
        """Store ``(name, weight)`` for this thread; ``name=None`` clears the entry."""
        cls._local.current = None if name is None else (name, weight)

    @classmethod
    def clear(cls):
        """Forget the entry stored for this thread."""
        cls._local.current = None


# Module-level alias through which the rest of the file reaches the context.
_COMPONENT_CONTEXT = _ComponentContext
def _tilelang_training_enabled() -> bool:
    """Report whether ``ARB_TILELANG_TRAINING`` is set to a truthy value.

    The value is stripped and lower-cased; ``1``, ``true`` and ``yes`` count
    as enabled.  Unset (default ``"0"``) or anything else means disabled.
    """
    truthy = {"1", "true", "yes"}
    flag = os.environ.get("ARB_TILELANG_TRAINING", "0")
    return flag.strip().lower() in truthy
# The JIT wrapper is only created when the optional ``tilelang`` import succeeded.
# Every TileLang prim_func in this file is compiled through ``tilelang_jit``, which
# sets the ``tl.disable_warp_specialized`` pass option to True.
if _HAS_TILELANG:
tilelang_jit = tilelang.jit(pass_configs={"tl.disable_warp_specialized": True})
def _ternary_fwd_kernel(
    M: int, N: int, K: int, group_size: int = 12,
    corr_strength: float = 4.0,
    block_M: int = 64, block_N: int = 64, block_K: int = 32,
    threads: int = 128, num_stages: int = 2,
):
    """Build the JIT-compiled TileLang forward kernel ``output = x @ W^T``.

    ``W`` is never materialised in global memory.  Each ``(block_N, block_K)``
    tile is dequantised into shared memory on the fly:

    * ``T_packed`` stores five base-3 digits ("trits") per byte over the
      row-major ``(N, K)`` weight; digit ``d`` maps to sign ``d - 1``.
    * ``E`` holds one int8 exponent per ``group_size`` consecutive columns of
      a row (``gpr`` groups per row).
    * The exponent is adjusted by
      ``corr_accum / max(step_counter * group_size, 1) * corr_strength`` and
      clamped to ``[-14, 15]`` before ``exp2``; the tile is stored as float16.

    Args:
        M, N, K: Problem sizes; ``x`` is ``(M, K)``, the output is ``(M, N)``.
        group_size: Number of weight columns sharing one exponent.
        corr_strength: Multiplier applied to the mean correction.
        block_M, block_N, block_K: Tile sizes.
        threads: Threads per block passed to ``T.Kernel``.
        num_stages: Pipeline depth of the K loop.

    Returns:
        The result of ``tilelang_jit(kernel)``, called as
        ``kernel(x, T_packed, E, corr_accum, step_counter, output)``.
    """
    gpr = (K + group_size - 1) // group_size
    cs = corr_strength
    gs_f = float(group_size)

    @T.prim_func
    def kernel(
        x: T.Tensor((M, K), "float16"),
        T_packed: T.Tensor((N * K + 4) // 5, "uint8"),
        E: T.Tensor((N * gpr), "int8"),
        corr_accum: T.Tensor((N * gpr), "int64"),
        step_counter: T.Tensor((1,), "int64"),
        output: T.Tensor((M, N), "float32"),
    ):
        # Convert the int64 counter straight to float: going through int32
        # (and multiplying by group_size in int32) can wrap for large counts.
        steps = T.cast(step_counter[0], "float32")
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
            x_shared = T.alloc_shared((block_M, block_K), dtype="float16")
            dq_shared = T.alloc_shared((block_N, block_K), dtype="float16")
            acc = T.alloc_fragment((block_M, block_N), dtype="float32")
            T.use_swizzle(10)
            T.clear(acc)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(x[bx * block_M, k * block_K], x_shared)
                for i, j in T.Parallel(block_N, block_K):
                    i_glob = by * block_N + i
                    j_glob = k * block_K + j
                    if i_glob < N and j_glob < K:
                        lin_idx = i_glob * K + j_glob
                        pack_idx = lin_idx // 5
                        trit_pos = lin_idx % 5
                        packed_val = T.cast(T_packed[pack_idx], "int32")
                        trit = T.if_then_else(
                            trit_pos == 0, packed_val % 3,
                            T.if_then_else(trit_pos == 1, (packed_val // 3) % 3,
                            T.if_then_else(trit_pos == 2, (packed_val // 9) % 3,
                            T.if_then_else(trit_pos == 3, (packed_val // 27) % 3,
                            (packed_val // 81) % 3))))
                        sign_val = T.cast(trit, "int32") - 1
                        exp_idx = i_glob * gpr + j_glob // group_size
                        exp_val = T.cast(E[exp_idx], "int32")
                        # int64 accumulator -> float32 directly (no int32 truncation).
                        ca = T.cast(corr_accum[exp_idx], "float32")
                        den = T.max(steps * gs_f, 1.0)
                        mc = ca / den
                        e_adj = T.cast(exp_val, "float32") + mc * cs
                        # Clamp so that exp2 stays representable in float16.
                        ecl = T.min(T.max(e_adj, -14.0), 15.0)
                        dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16")
                    else:
                        # Boundary tile: zero the padding so stale shared
                        # memory cannot leak NaN/Inf into the GEMM.
                        dq_shared[i, j] = T.cast(0.0, "float16")
                T.gemm(x_shared, dq_shared, acc, transpose_B=True)
            T.copy(acc, output[bx * block_M, by * block_N])

    return tilelang_jit(kernel)
def _ternary_grad_x_kernel(
    M: int, N: int, K: int, group_size: int = 12,
    corr_strength: float = 4.0,
    block_M: int = 64, block_N: int = 64, block_K: int = 32,
    threads: int = 128, num_stages: int = 2,
):
    """Build the JIT-compiled TileLang input-gradient kernel ``output = grad_y @ W``.

    The ``(N, K)`` weight is dequantised tile by tile exactly as in the
    forward kernel: base-3 digit minus one gives the sign, and the int8 group
    exponent is adjusted by
    ``corr_accum / max(step_counter * group_size, 1) * corr_strength``,
    clamped to ``[-14, 15]`` and stored as float16.  The reduction runs over
    ``N`` in ``block_N`` steps.

    Args:
        M, N, K: Problem sizes; ``grad_y`` is ``(M, N)``, the output ``(M, K)``.
        group_size: Number of weight columns sharing one exponent.
        corr_strength: Multiplier applied to the mean correction.
        block_M, block_N, block_K: Tile sizes.
        threads: Threads per block passed to ``T.Kernel``.
        num_stages: Pipeline depth of the N loop.

    Returns:
        The result of ``tilelang_jit(kernel)``, called as
        ``kernel(grad_y, T_packed, E, corr_accum, step_counter, output)``.
    """
    gpr = (K + group_size - 1) // group_size
    cs = corr_strength
    gs_f = float(group_size)

    @T.prim_func
    def kernel(
        grad_y: T.Tensor((M, N), "float16"),
        T_packed: T.Tensor((N * K + 4) // 5, "uint8"),
        E: T.Tensor((N * gpr), "int8"),
        corr_accum: T.Tensor((N * gpr), "int64"),
        step_counter: T.Tensor((1,), "int64"),
        output: T.Tensor((M, K), "float32"),
    ):
        # Convert the int64 counter straight to float: going through int32
        # (and multiplying by group_size in int32) can wrap for large counts.
        steps = T.cast(step_counter[0], "float32")
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=threads) as (bx, by):
            gy_shared = T.alloc_shared((block_M, block_N), dtype="float16")
            dq_shared = T.alloc_shared((block_N, block_K), dtype="float16")
            acc = T.alloc_fragment((block_M, block_K), dtype="float32")
            T.use_swizzle(10)
            T.clear(acc)
            for n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
                T.copy(grad_y[bx * block_M, n * block_N], gy_shared)
                for i, j in T.Parallel(block_N, block_K):
                    i_glob = n * block_N + i
                    j_glob = by * block_K + j
                    if i_glob < N and j_glob < K:
                        lin_idx = i_glob * K + j_glob
                        pack_idx = lin_idx // 5
                        trit_pos = lin_idx % 5
                        packed_val = T.cast(T_packed[pack_idx], "int32")
                        trit = T.if_then_else(
                            trit_pos == 0, packed_val % 3,
                            T.if_then_else(trit_pos == 1, (packed_val // 3) % 3,
                            T.if_then_else(trit_pos == 2, (packed_val // 9) % 3,
                            T.if_then_else(trit_pos == 3, (packed_val // 27) % 3,
                            (packed_val // 81) % 3))))
                        sign_val = T.cast(trit, "int32") - 1
                        exp_idx = i_glob * gpr + j_glob // group_size
                        exp_val = T.cast(E[exp_idx], "int32")
                        # int64 accumulator -> float32 directly (no int32 truncation).
                        ca = T.cast(corr_accum[exp_idx], "float32")
                        den = T.max(steps * gs_f, 1.0)
                        mc = ca / den
                        e_adj = T.cast(exp_val, "float32") + mc * cs
                        # Clamp so that exp2 stays representable in float16.
                        ecl = T.min(T.max(e_adj, -14.0), 15.0)
                        dq_shared[i, j] = T.cast(T.exp2(ecl) * T.cast(sign_val, "float32"), "float16")
                    else:
                        # Rows past N take part in the reduction: zero them
                        # so stale shared memory cannot poison the result.
                        dq_shared[i, j] = T.cast(0.0, "float16")
                T.gemm(gy_shared, dq_shared, acc)
            T.copy(acc, output[bx * block_M, by * block_K])

    return tilelang_jit(kernel)
_KERNEL_CACHE_FWD = {}
_KERNEL_CACHE_GX = {}
def _get_kernel(M, N, K, group_size, mode, corr_strength=4.0):
cs = corr_strength
if mode == "fwd":
cache = _KERNEL_CACHE_FWD
key = (M, N, K, group_size, cs)
if key not in cache:
cache[key] = _ternary_fwd_kernel(M, N, K, group_size, corr_strength=cs)
return cache[key]
elif mode == "grad_x":
cache = _KERNEL_CACHE_GX
key = (M, N, K, group_size)
if key not in cache:
cache[key] = _ternary_grad_x_kernel(M, N, K, group_size)
return cache[key]
raise ValueError(f"Unknown TileLang kernel mode: {mode}")
def _get_grad_kernels(M, N, K, group_size, corr_strength=4.0):
    """Return the cached TileLang input-gradient kernel for the given shape.

    ``corr_strength`` is forwarded to ``_get_kernel`` so a caller can request
    the strength its forward kernel was built with.  The default of 4.0
    matches what callers that omit the argument have always received.
    """
    return _get_kernel(M, N, K, group_size, "grad_x", corr_strength=corr_strength)
# Custom autograd function for a linear layer whose weight is stored as packed
# trits (module.T_packed) plus int8 group exponents (module.E).
# forward runs the supplied TileLang forward kernel; backward runs the TileLang
# grad_x kernel and then either streams a correction update into
# module.corr_accum (Triton path) or stashes tensors on the module.
class _TernaryLinearFn(torch.autograd.Function):
@staticmethod
# forward(ctx, x, module, fwd_kernel): x has K features on its last axis;
# returns a float32 tensor shaped (*x.shape[:-1], N) where (N, K) = module._T_shape.
def forward(ctx, x, module, fwd_kernel):
ctx.module = module
T_packed = module.T_packed
E = module.E
shape = tuple(module._T_shape.tolist())
N, K = shape
# Flatten all leading axes so the kernel sees a contiguous (M, K) matrix.
x_2d = x.reshape(-1, K).contiguous()
ctx.group_size = module.group_size
ctx.shape = shape
ctx.x_shape = x.shape
# Remember which component (if any) was active for this thread at forward time.
comp_name, _ = _COMPONENT_CONTEXT.get()
ctx.comp_name = comp_name
ctx.x_dtype = x.dtype
# The correction path is used only when the module carries both buffers.
has_corr = hasattr(module, "corr_accum") and hasattr(module, "step_counter")
ctx.save_for_backward(x_2d, T_packed, E)
ctx.has_corr = has_corr
# Snapshot the step count as a Python int (.item() copies it to the host) so
# backward can rebuild a step_counter tensor with the forward-time value.
ctx.step_snapshot = int(module.step_counter.item()) if has_corr else 0
with torch.no_grad():
M = x_2d.shape[0]
output = torch.empty(M, N, device=x.device, dtype=torch.float32)
if has_corr:
# The kernel is fed float16 activations (x_2d.half()).
fwd_kernel(x_2d.half(), T_packed, E,
module.corr_accum.contiguous(),
module.step_counter.contiguous(), output)
else:
# No correction state on the module: pass all-zero correction and step buffers.
fwd_kernel(x_2d.half(), T_packed, E,
torch.zeros(N * ((K + module.group_size - 1) // module.group_size),
dtype=torch.int64, device=x.device),
torch.zeros(1, dtype=torch.int64, device=x.device), output)
return output.reshape(*x.shape[:-1], N)
@staticmethod
# backward(ctx, grad_output): returns (grad_x, None, None); only x receives a gradient.
def backward(ctx, grad_output):
x_2d, T_packed, E = ctx.saved_tensors
group_size = ctx.group_size
N, K = ctx.shape
M = x_2d.shape[0]
grad_2d = grad_output.reshape(-1, N).contiguous()
if ctx.has_corr:
# Uses the module's current corr_accum together with the step count captured in forward.
corr_accum = ctx.module.corr_accum.contiguous()
step_counter = torch.tensor([ctx.step_snapshot], dtype=torch.int64, device=x_2d.device)
else:
corr_accum = torch.zeros(N * ((K + group_size - 1) // group_size),
dtype=torch.int64, device=x_2d.device)
step_counter = torch.zeros(1, dtype=torch.int64, device=x_2d.device)
# NOTE(review): the grad kernel is requested without a corr_strength argument, so it
# uses the builder default — confirm this matches the strength of the forward kernel.
grad_x_kernel = _get_grad_kernels(M, N, K, group_size)
with torch.no_grad():
grad_x = torch.empty(M, K, device=x_2d.device, dtype=torch.float32)
grad_x_kernel(grad_2d.half(), T_packed, E, corr_accum, step_counter, grad_x)
comp_name = ctx.comp_name
# Path 1: Triton available, correction buffers present and streaming not disabled on
# the module -> fold this backward pass directly into corr_accum.
if _HAS_TRITON and ctx.has_corr and getattr(ctx.module, "_stream_backward_updates", True):
# The component weight is read from the context at backward time, not forward time.
bwd_name, bwd_weight = _COMPONENT_CONTEXT.get()
if bwd_name is None:
bwd_weight = 1.0
base_step = int(getattr(ctx.module, "_backward_t_accum_step", 1))
# Step magnitude is at least 1; the sign follows the sign of the component weight.
corr_step = max(1, int(round(abs(float(bwd_weight)) * base_step)))
if bwd_weight < 0:
corr_step = -corr_step
_triton_accumulate_corr_direct(
T_packed, grad_2d, x_2d, ctx.module.corr_accum,
N, K, group_size, corr_step=corr_step,
)
# The step counter always advances by the magnitude of the step.
ctx.module.step_counter.add_(abs(corr_step))
ctx.module._streamed_bigint_backward = True
# Path 2: Triton available but not streaming -> stash the int8 sign of grad^T @ x on the module.
elif _HAS_TRITON:
grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, N, K)
if comp_name is not None:
setattr(ctx.module, f"_hook_grad_T_sign_{comp_name}", grad_sign.detach())
else:
ctx.module._hook_grad_T_sign = grad_sign.detach()
# Path 3: no Triton -> stash the raw grad and input matrices, per component if one was active.
elif comp_name is not None:
setattr(ctx.module, f"_hook_grad_2d_{comp_name}", grad_2d.detach())
setattr(ctx.module, f"_hook_x_2d_{comp_name}", x_2d.detach())
else:
ctx.module._hook_grad_2d = grad_2d.detach()
ctx.module._hook_x_2d = x_2d.detach()
# Restore the caller's input shape and dtype for the returned gradient.
grad_x_reshaped = grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype)
return grad_x_reshaped, None, None
if _HAS_TRITON:
@triton.jit
def _triton_ternary_fwd_kernel(
    x_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    CORR_STRENGTH: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Forward pass ``out[M, N] = x[M, K] @ W[N, K]^T`` with on-the-fly dequantisation.

    Each program owns one ``BLOCK_M x BLOCK_N`` output tile and walks K in
    ``BLOCK_K`` steps.  Weights are decoded from five base-3 digits per byte
    (digit - 1 = sign) and scaled by
    ``exp2(E + corr / max(step * GROUP_SIZE, 1) * CORR_STRENGTH)``.
    Indexing assumes row-major contiguous ``x`` and ``out``.
    """
    row_ids = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    col_ids = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    k_lane = tl.arange(0, BLOCK_K)
    row_ok = row_ids[:, None] < M
    w_row_ok = col_ids[:, None] < N
    # The step counter is a single scalar; the divisor is the same for every tile.
    denom = tl.maximum(tl.load(step_ptr).to(tl.float32) * GROUP_SIZE, 1.0)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        k_ids = k0 + k_lane
        k_ok = k_ids[None, :] < K
        w_mask = w_row_ok & k_ok
        x_tile = tl.load(
            x_ptr + row_ids[:, None] * K + k_ids[None, :],
            mask=row_ok & k_ok,
            other=0.0,
        )
        # Locate each weight's byte and its digit position inside that byte.
        flat = col_ids[:, None] * K + k_ids[None, :]
        byte_idx = flat // 5
        digit = flat - byte_idx * 5
        byte_val = tl.load(packed_ptr + byte_idx, mask=w_mask, other=0).to(tl.int32)
        pow3 = tl.where(
            digit == 0, 1,
            tl.where(digit == 1, 3,
            tl.where(digit == 2, 9,
            tl.where(digit == 3, 27, 81))),
        )
        sign = ((byte_val // pow3) % 3).to(tl.int32) - 1
        grp = col_ids[:, None] * GPR + k_ids[None, :] // GROUP_SIZE
        e_val = tl.load(e_ptr + grp, mask=w_mask, other=0).to(tl.float32)
        corr_val = tl.load(corr_ptr + grp, mask=w_mask, other=0).to(tl.float32)
        e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
        # Masked-off lanes decode to sign -1 (byte 0), so force them to zero.
        w = tl.where(w_mask, sign.to(tl.float32) * tl.exp2(e_adj), 0.0)
        acc += tl.dot(x_tile, tl.trans(w))
    tl.store(
        out_ptr + row_ids[:, None] * N + col_ids[None, :],
        acc,
        mask=row_ok & (col_ids[None, :] < N),
    )
@triton.jit
def _triton_ternary_grad_x_kernel(
    grad_ptr, packed_ptr, e_ptr, corr_ptr, step_ptr, out_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    CORR_STRENGTH: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Input gradient ``out[M, K] = grad[M, N] @ W[N, K]`` with on-the-fly dequantisation.

    Each program owns one ``BLOCK_M x BLOCK_K`` output tile and reduces over
    N in ``BLOCK_N`` steps.  The weight decoding matches the forward kernel:
    base-3 digit minus one gives the sign, scaled by
    ``exp2(E + corr / max(step * GROUP_SIZE, 1) * CORR_STRENGTH)``.
    """
    row_ids = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    k_ids = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    n_lane = tl.arange(0, BLOCK_N)
    row_ok = row_ids[:, None] < M
    k_ok = k_ids[None, :] < K
    # The step counter is a single scalar; the divisor is the same for every tile.
    denom = tl.maximum(tl.load(step_ptr).to(tl.float32) * GROUP_SIZE, 1.0)
    acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    for n0 in range(0, N, BLOCK_N):
        n_ids = n0 + n_lane
        g_tile = tl.load(
            grad_ptr + row_ids[:, None] * N + n_ids[None, :],
            mask=row_ok & (n_ids[None, :] < N),
            other=0.0,
        )
        w_mask = (n_ids[:, None] < N) & k_ok
        # Locate each weight's byte and its digit position inside that byte.
        flat = n_ids[:, None] * K + k_ids[None, :]
        byte_idx = flat // 5
        digit = flat - byte_idx * 5
        byte_val = tl.load(packed_ptr + byte_idx, mask=w_mask, other=0).to(tl.int32)
        pow3 = tl.where(
            digit == 0, 1,
            tl.where(digit == 1, 3,
            tl.where(digit == 2, 9,
            tl.where(digit == 3, 27, 81))),
        )
        sign = ((byte_val // pow3) % 3).to(tl.int32) - 1
        grp = n_ids[:, None] * GPR + k_ids[None, :] // GROUP_SIZE
        e_val = tl.load(e_ptr + grp, mask=w_mask, other=0).to(tl.float32)
        corr_val = tl.load(corr_ptr + grp, mask=w_mask, other=0).to(tl.float32)
        e_adj = e_val + (corr_val / denom) * CORR_STRENGTH
        # Masked-off lanes decode to sign -1 (byte 0), so force them to zero.
        w = tl.where(w_mask, sign.to(tl.float32) * tl.exp2(e_adj), 0.0)
        acc += tl.dot(g_tile, w)
    tl.store(
        out_ptr + row_ids[:, None] * K + k_ids[None, :],
        acc,
        mask=row_ok & k_ok,
    )
@triton.jit
def _triton_ternary_grad_sign_kernel(
    grad_ptr, x_ptr, sign_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Write ``sign(grad^T @ x)`` as int8 values in {-1, 0, 1} to ``sign_ptr[N, K]``.

    Each program owns one ``BLOCK_N x BLOCK_K`` tile of the weight gradient
    and reduces over the M rows in ``BLOCK_M`` steps using an IEEE-precision
    dot product.
    """
    n_ids = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    k_ids = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    m_lane = tl.arange(0, BLOCK_M)
    k_ok = k_ids[None, :] < K
    acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
    for m0 in range(0, M, BLOCK_M):
        m_ids = m0 + m_lane
        m_ok = m_ids[:, None] < M
        g_tile = tl.load(
            grad_ptr + m_ids[:, None] * N + n_ids[None, :],
            mask=m_ok & (n_ids[None, :] < N),
            other=0.0,
        )
        x_tile = tl.load(
            x_ptr + m_ids[:, None] * K + k_ids[None, :],
            mask=m_ok & k_ok,
            other=0.0,
        )
        acc += tl.dot(tl.trans(g_tile), x_tile, input_precision="ieee")
    sgn = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0))
    tl.store(
        sign_ptr + n_ids[:, None] * K + k_ids[None, :],
        sgn.to(tl.int8),
        mask=(n_ids[:, None] < N) & k_ok,
    )
@triton.jit
def _triton_update_e_kernel(
    packed_ptr, grad_sign_ptr, e_ptr, e_accum_ptr,
    N: tl.constexpr, K: tl.constexpr,
    GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
    E_ACCUM_THRESHOLD: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_G: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Update group exponents from a precomputed gradient-sign matrix.

    For every (row, group) the kernel sums ``grad_sign * ternary`` over the
    group's columns, nudges the int8 accumulator by -1/0/+1 against the sign
    of that sum, and when the accumulator reaches ``+/-E_ACCUM_THRESHOLD``
    moves the int8 exponent by one step and subtracts the threshold back out.
    Both accumulator and exponent are kept within [-128, 127].
    """
    n_ids = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    g_ids = tl.program_id(1) * BLOCK_G + tl.arange(0, BLOCK_G)
    lane = tl.arange(0, BLOCK_K)
    # Column index of every lane of every group handled by this program.
    k = g_ids[:, None] * GROUP_SIZE + lane[None, :]
    g_ok = g_ids < GPR
    elem_ok = (n_ids[:, None, None] < N) & g_ok[None, :, None] & (lane[None, None, :] < GROUP_SIZE) & (k[None, :, :] < K)
    flat = n_ids[:, None, None] * K + k[None, :, :]
    byte_idx = flat // 5
    digit = flat - byte_idx * 5
    byte_val = tl.load(packed_ptr + byte_idx, mask=elem_ok, other=0).to(tl.int32)
    pow3 = tl.where(
        digit == 0, 1,
        tl.where(digit == 1, 3,
        tl.where(digit == 2, 9,
        tl.where(digit == 3, 27, 81))),
    )
    ternary = ((byte_val // pow3) % 3).to(tl.int32) - 1
    # Masked-off lanes load grad_sign 0, so they add nothing to the score.
    g_sign = tl.load(
        grad_sign_ptr + n_ids[:, None, None] * K + k[None, :, :],
        mask=elem_ok,
        other=0,
    ).to(tl.int32)
    score = tl.sum(g_sign * ternary, axis=2)
    delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
    slot = n_ids[:, None] * GPR + g_ids[None, :]
    slot_ok = (n_ids[:, None] < N) & g_ok[None, :]
    accum_prev = tl.load(e_accum_ptr + slot, mask=slot_ok, other=0).to(tl.int32)
    accum_next = tl.minimum(127, tl.maximum(-128, accum_prev + delta))
    e_step = tl.where(
        accum_next >= E_ACCUM_THRESHOLD, 1,
        tl.where(accum_next <= -E_ACCUM_THRESHOLD, -1, 0),
    )
    accum_out = accum_next - e_step * E_ACCUM_THRESHOLD
    e_prev = tl.load(e_ptr + slot, mask=slot_ok, other=0).to(tl.int32)
    e_next = tl.minimum(127, tl.maximum(-128, e_prev + e_step))
    tl.store(e_ptr + slot, e_next.to(tl.int8), mask=slot_ok)
    tl.store(e_accum_ptr + slot, accum_out.to(tl.int8), mask=slot_ok)
@triton.jit
def _triton_ternary_step_kernel(
    packed_ptr, grad_sign_ptr, accum_ptr, per_group_threshold_ptr,
    TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
    T_ACCUM_STEP: tl.constexpr,
    K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    HAS_PER_GROUP_THRESHOLD: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    """Apply one sign-accumulator step to the five trits of one packed byte.

    One program handles one byte.  Each trit's int8 accumulator moves by
    ``-grad_sign * T_ACCUM_STEP`` (saturating to [-128, 127]).  A trit whose
    accumulator exceeds the threshold becomes +1, one below the negated
    threshold becomes -1; in both cases its accumulator is reset to 0.  The
    byte is then re-encoded in base 3; positions at or past ``TOTAL`` are
    written as digit 0.
    """
    byte_id = tl.program_id(0)
    lane = tl.arange(0, BLOCK_T)
    flat = byte_id * 5 + lane
    # Only the first five lanes map to trits; the tail byte may be partial.
    live = (lane < 5) & (flat < TOTAL)
    byte_prev = tl.load(packed_ptr + byte_id).to(tl.int32)
    pow3 = tl.where(
        lane == 0, 1,
        tl.where(lane == 1, 3,
        tl.where(lane == 2, 9,
        tl.where(lane == 3, 27, 81))),
    )
    sign_prev = ((byte_prev // pow3) % 3).to(tl.int32) - 1
    g_sign = tl.load(grad_sign_ptr + flat, mask=live, other=0).to(tl.int32)
    accum_prev = tl.load(accum_ptr + flat, mask=live, other=0).to(tl.int32)
    accum_next = tl.minimum(127, tl.maximum(-128, accum_prev - g_sign * T_ACCUM_STEP))
    if HAS_PER_GROUP_THRESHOLD:
        row = flat // K
        col = flat - row * K
        grp = row * GPR + col // GROUP_SIZE
        threshold = tl.load(per_group_threshold_ptr + grp, mask=live, other=ACCUM_THRESHOLD).to(tl.int32)
    else:
        threshold = ACCUM_THRESHOLD
    go_up = accum_next > threshold
    go_down = accum_next < -threshold
    flipped = live & (go_up | go_down)
    sign_next = tl.where(go_up, 1, tl.where(go_down, -1, sign_prev))
    accum_out = tl.where(flipped, 0, accum_next)
    tl.store(accum_ptr + flat, accum_out.to(tl.int8), mask=live)
    # Dead lanes contribute digit 0, so they never disturb the encoded byte.
    code_next = tl.where(live, sign_next + 1, 0)
    byte_next = tl.sum(code_next * pow3, axis=0)
    tl.store(packed_ptr + byte_id, byte_next.to(tl.uint8))
@triton.jit
def _triton_update_e_direct_kernel(
    packed_ptr, grad_ptr, x_ptr, e_ptr, e_accum_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
    E_ACCUM_THRESHOLD: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Update group exponents directly from ``grad`` and ``x``.

    Each program handles ``BLOCK_N`` rows of a single group
    (``program_id(1)``).  It computes ``sign(grad^T @ x)`` for the group's
    columns, sums ``sign * ternary`` per row, nudges the int8 accumulator by
    -1/0/+1 against that sum, and steps the int8 exponent whenever the
    accumulator reaches ``+/-E_ACCUM_THRESHOLD``.
    """
    n_ids = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    group = tl.program_id(1)
    lane = tl.arange(0, BLOCK_K)
    k_ids = group * GROUP_SIZE + lane
    m_lane = tl.arange(0, BLOCK_M)
    # BLOCK_K may exceed GROUP_SIZE, so lanes past the group are disabled.
    col_ok = (lane[None, :] < GROUP_SIZE) & (k_ids[None, :] < K)
    acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
    for m0 in range(0, M, BLOCK_M):
        m_ids = m0 + m_lane
        m_ok = m_ids[:, None] < M
        g_tile = tl.load(
            grad_ptr + m_ids[:, None] * N + n_ids[None, :],
            mask=m_ok & (n_ids[None, :] < N),
            other=0.0,
        )
        x_tile = tl.load(
            x_ptr + m_ids[:, None] * K + k_ids[None, :],
            mask=m_ok & col_ok,
            other=0.0,
        )
        acc += tl.dot(tl.trans(g_tile), x_tile, input_precision="ieee")
    g_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
    elem_ok = (n_ids[:, None] < N) & col_ok
    flat = n_ids[:, None] * K + k_ids[None, :]
    byte_idx = flat // 5
    digit = flat - byte_idx * 5
    byte_val = tl.load(packed_ptr + byte_idx, mask=elem_ok, other=0).to(tl.int32)
    pow3 = tl.where(
        digit == 0, 1,
        tl.where(digit == 1, 3,
        tl.where(digit == 2, 9,
        tl.where(digit == 3, 27, 81))),
    )
    ternary = ((byte_val // pow3) % 3).to(tl.int32) - 1
    score = tl.sum(tl.where(elem_ok, g_sign * ternary, 0), axis=1)
    delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
    slot = n_ids * GPR + group
    row_ok = n_ids < N
    accum_prev = tl.load(e_accum_ptr + slot, mask=row_ok, other=0).to(tl.int32)
    accum_next = tl.minimum(127, tl.maximum(-128, accum_prev + delta))
    e_step = tl.where(
        accum_next >= E_ACCUM_THRESHOLD, 1,
        tl.where(accum_next <= -E_ACCUM_THRESHOLD, -1, 0),
    )
    accum_out = accum_next - e_step * E_ACCUM_THRESHOLD
    e_prev = tl.load(e_ptr + slot, mask=row_ok, other=0).to(tl.int32)
    e_next = tl.minimum(127, tl.maximum(-128, e_prev + e_step))
    tl.store(e_ptr + slot, e_next.to(tl.int8), mask=row_ok)
    tl.store(e_accum_ptr + slot, accum_out.to(tl.int8), mask=row_ok)
@triton.jit
def _triton_ternary_step_direct_kernel(
    packed_ptr, grad_ptr, x_ptr, accum_ptr, per_group_threshold_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
    T_ACCUM_STEP: tl.constexpr,
    GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    HAS_PER_GROUP_THRESHOLD: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
):
    """Sign-accumulator step for one packed byte, computing the gradient sign in-kernel.

    One program handles one byte.  For each of its trits at weight position
    ``(row, col)`` the kernel reduces ``sum_m grad[m, row] * x[m, col]`` and
    takes its sign, then applies the same accumulate / threshold / flip /
    re-encode sequence as the precomputed-sign step kernel.
    """
    byte_id = tl.program_id(0)
    lane = tl.arange(0, BLOCK_T)
    flat = byte_id * 5 + lane
    # Only the first five lanes map to trits; the tail byte may be partial.
    live = (lane < 5) & (flat < TOTAL)
    row = flat // K
    col = flat - row * K
    m_lane = tl.arange(0, BLOCK_M)
    acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
    for m0 in range(0, M, BLOCK_M):
        m_ids = m0 + m_lane
        tile_ok = (m_ids[:, None] < M) & live[None, :]
        g_tile = tl.load(grad_ptr + m_ids[:, None] * N + row[None, :], mask=tile_ok, other=0.0)
        x_tile = tl.load(x_ptr + m_ids[:, None] * K + col[None, :], mask=tile_ok, other=0.0)
        acc += tl.sum(g_tile * x_tile, axis=0)
    g_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
    byte_prev = tl.load(packed_ptr + byte_id).to(tl.int32)
    pow3 = tl.where(
        lane == 0, 1,
        tl.where(lane == 1, 3,
        tl.where(lane == 2, 9,
        tl.where(lane == 3, 27, 81))),
    )
    sign_prev = ((byte_prev // pow3) % 3).to(tl.int32) - 1
    accum_prev = tl.load(accum_ptr + flat, mask=live, other=0).to(tl.int32)
    accum_next = tl.minimum(127, tl.maximum(-128, accum_prev - g_sign * T_ACCUM_STEP))
    if HAS_PER_GROUP_THRESHOLD:
        grp = row * GPR + col // GROUP_SIZE
        threshold = tl.load(per_group_threshold_ptr + grp, mask=live, other=ACCUM_THRESHOLD).to(tl.int32)
    else:
        threshold = ACCUM_THRESHOLD
    go_up = accum_next > threshold
    go_down = accum_next < -threshold
    flipped = live & (go_up | go_down)
    sign_next = tl.where(go_up, 1, tl.where(go_down, -1, sign_prev))
    accum_out = tl.where(flipped, 0, accum_next)
    tl.store(accum_ptr + flat, accum_out.to(tl.int8), mask=live)
    # Dead lanes contribute digit 0, so they never disturb the encoded byte.
    code_next = tl.where(live, sign_next + 1, 0)
    byte_next = tl.sum(code_next * pow3, axis=0)
    tl.store(packed_ptr + byte_id, byte_next.to(tl.uint8))
@triton.jit
def _triton_accumulate_t_direct_kernel(
    grad_ptr, x_ptr, accum_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    TOTAL: tl.constexpr, T_ACCUM_STEP: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_T: tl.constexpr,
):
    """Accumulate gradient signs into the per-trit int8 accumulator without flipping.

    One program covers the five weight positions that share a packed byte.
    For each it takes the sign of ``sum_m grad[m, row] * x[m, col]`` and moves
    the accumulator by ``-sign * T_ACCUM_STEP``, saturating to [-128, 127].
    The packed trits themselves are not touched.
    """
    lane = tl.arange(0, BLOCK_T)
    flat = tl.program_id(0) * 5 + lane
    live = (lane < 5) & (flat < TOTAL)
    row = flat // K
    col = flat - row * K
    m_lane = tl.arange(0, BLOCK_M)
    acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
    for m0 in range(0, M, BLOCK_M):
        m_ids = m0 + m_lane
        tile_ok = (m_ids[:, None] < M) & live[None, :]
        g_tile = tl.load(grad_ptr + m_ids[:, None] * N + row[None, :], mask=tile_ok, other=0.0)
        x_tile = tl.load(x_ptr + m_ids[:, None] * K + col[None, :], mask=tile_ok, other=0.0)
        acc += tl.sum(g_tile * x_tile, axis=0)
    g_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
    accum_prev = tl.load(accum_ptr + flat, mask=live, other=0).to(tl.int32)
    accum_next = tl.minimum(127, tl.maximum(-128, accum_prev - g_sign * T_ACCUM_STEP))
    tl.store(accum_ptr + flat, accum_next.to(tl.int8), mask=live)
@triton.jit
def _triton_accumulate_e_direct_kernel(
    packed_ptr, grad_ptr, x_ptr, e_accum_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
    E_ACCUM_STEP: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Accumulate the exponent update direction without changing the exponents.

    Each program handles ``BLOCK_N`` rows of one group.  It computes
    ``sign(grad^T @ x)`` for the group's columns, sums ``sign * ternary`` per
    row and moves the int8 accumulator by ``E_ACCUM_STEP`` against the sign
    of that sum, saturating to [-128, 127].
    """
    n_ids = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    group = tl.program_id(1)
    lane = tl.arange(0, BLOCK_K)
    k_ids = group * GROUP_SIZE + lane
    m_lane = tl.arange(0, BLOCK_M)
    # BLOCK_K may exceed GROUP_SIZE, so lanes past the group are disabled.
    col_ok = (lane[None, :] < GROUP_SIZE) & (k_ids[None, :] < K)
    acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
    for m0 in range(0, M, BLOCK_M):
        m_ids = m0 + m_lane
        m_ok = m_ids[:, None] < M
        g_tile = tl.load(
            grad_ptr + m_ids[:, None] * N + n_ids[None, :],
            mask=m_ok & (n_ids[None, :] < N),
            other=0.0,
        )
        x_tile = tl.load(
            x_ptr + m_ids[:, None] * K + k_ids[None, :],
            mask=m_ok & col_ok,
            other=0.0,
        )
        acc += tl.dot(tl.trans(g_tile), x_tile, input_precision="ieee")
    g_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
    elem_ok = (n_ids[:, None] < N) & col_ok
    flat = n_ids[:, None] * K + k_ids[None, :]
    byte_idx = flat // 5
    digit = flat - byte_idx * 5
    byte_val = tl.load(packed_ptr + byte_idx, mask=elem_ok, other=0).to(tl.int32)
    pow3 = tl.where(
        digit == 0, 1,
        tl.where(digit == 1, 3,
        tl.where(digit == 2, 9,
        tl.where(digit == 3, 27, 81))),
    )
    ternary = ((byte_val // pow3) % 3).to(tl.int32) - 1
    score = tl.sum(tl.where(elem_ok, g_sign * ternary, 0), axis=1)
    delta = tl.where(score > 0, -1, tl.where(score < 0, 1, 0))
    slot = n_ids * GPR + group
    row_ok = n_ids < N
    accum_prev = tl.load(e_accum_ptr + slot, mask=row_ok, other=0).to(tl.int32)
    accum_next = tl.minimum(127, tl.maximum(-128, accum_prev + delta * E_ACCUM_STEP))
    tl.store(e_accum_ptr + slot, accum_next.to(tl.int8), mask=row_ok)
@triton.jit
def _triton_accumulate_corr_direct_kernel(
    packed_ptr, grad_ptr, x_ptr, corr_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    GROUP_SIZE: tl.constexpr, GPR: tl.constexpr,
    CORR_STEP: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Fold one backward pass into the int64 per-group correction accumulator.

    Each program handles ``BLOCK_N`` rows of one group.  It computes
    ``sign(grad^T @ x)`` for the group's columns, sums ``sign * ternary`` per
    row, and subtracts ``score * CORR_STEP`` from ``corr_ptr[row, group]``.
    Unlike the int8 accumulators, the full score is applied without
    saturation.
    """
    n_ids = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    group = tl.program_id(1)
    lane = tl.arange(0, BLOCK_K)
    k_ids = group * GROUP_SIZE + lane
    m_lane = tl.arange(0, BLOCK_M)
    # BLOCK_K may exceed GROUP_SIZE, so lanes past the group are disabled.
    col_ok = (lane[None, :] < GROUP_SIZE) & (k_ids[None, :] < K)
    acc = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
    for m0 in range(0, M, BLOCK_M):
        m_ids = m0 + m_lane
        m_ok = m_ids[:, None] < M
        g_tile = tl.load(
            grad_ptr + m_ids[:, None] * N + n_ids[None, :],
            mask=m_ok & (n_ids[None, :] < N),
            other=0.0,
        )
        x_tile = tl.load(
            x_ptr + m_ids[:, None] * K + k_ids[None, :],
            mask=m_ok & col_ok,
            other=0.0,
        )
        acc += tl.dot(tl.trans(g_tile), x_tile, input_precision="ieee")
    g_sign = tl.where(acc > 0.0, 1, tl.where(acc < 0.0, -1, 0)).to(tl.int32)
    elem_ok = (n_ids[:, None] < N) & col_ok
    flat = n_ids[:, None] * K + k_ids[None, :]
    byte_idx = flat // 5
    digit = flat - byte_idx * 5
    byte_val = tl.load(packed_ptr + byte_idx, mask=elem_ok, other=0).to(tl.int32)
    pow3 = tl.where(
        digit == 0, 1,
        tl.where(digit == 1, 3,
        tl.where(digit == 2, 9,
        tl.where(digit == 3, 27, 81))),
    )
    ternary = ((byte_val // pow3) % 3).to(tl.int32) - 1
    score = tl.sum(tl.where(elem_ok, g_sign * ternary, 0), axis=1)
    slot = n_ids * GPR + group
    row_ok = n_ids < N
    corr_prev = tl.load(corr_ptr + slot, mask=row_ok, other=0).to(tl.int64)
    corr_next = corr_prev - score.to(tl.int64) * CORR_STEP
    tl.store(corr_ptr + slot, corr_next, mask=row_ok)
@triton.jit
def _triton_apply_accumulated_flips_kernel(
    packed_ptr, accum_ptr, per_group_threshold_ptr,
    TOTAL: tl.constexpr, ACCUM_THRESHOLD: tl.constexpr,
    K: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    HAS_PER_GROUP_THRESHOLD: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    """Turn already-accumulated int8 counters into trit flips for one packed byte.

    No new gradient is added: the stored accumulator is compared with the
    threshold as-is.  Above the threshold the trit becomes +1, below the
    negated threshold -1, and in both cases the accumulator is reset to 0.
    The byte is re-encoded in base 3 with digit 0 at positions at or past
    ``TOTAL``.
    """
    byte_id = tl.program_id(0)
    lane = tl.arange(0, BLOCK_T)
    flat = byte_id * 5 + lane
    # Only the first five lanes map to trits; the tail byte may be partial.
    live = (lane < 5) & (flat < TOTAL)
    byte_prev = tl.load(packed_ptr + byte_id).to(tl.int32)
    pow3 = tl.where(
        lane == 0, 1,
        tl.where(lane == 1, 3,
        tl.where(lane == 2, 9,
        tl.where(lane == 3, 27, 81))),
    )
    sign_prev = ((byte_prev // pow3) % 3).to(tl.int32) - 1
    accum_prev = tl.load(accum_ptr + flat, mask=live, other=0).to(tl.int32)
    if HAS_PER_GROUP_THRESHOLD:
        row = flat // K
        col = flat - row * K
        grp = row * GPR + col // GROUP_SIZE
        threshold = tl.load(per_group_threshold_ptr + grp, mask=live, other=ACCUM_THRESHOLD).to(tl.int32)
    else:
        threshold = ACCUM_THRESHOLD
    go_up = accum_prev > threshold
    go_down = accum_prev < -threshold
    flipped = live & (go_up | go_down)
    sign_next = tl.where(go_up, 1, tl.where(go_down, -1, sign_prev))
    accum_out = tl.where(flipped, 0, accum_prev)
    tl.store(accum_ptr + flat, accum_out.to(tl.int8), mask=live)
    # Dead lanes contribute digit 0, so they never disturb the encoded byte.
    code_next = tl.where(live, sign_next + 1, 0)
    byte_next = tl.sum(code_next * pow3, axis=0)
    tl.store(packed_ptr + byte_id, byte_next.to(tl.uint8))
def _triton_ternary_forward(x_2d, packed, e, corr_accum, step_counter, n_out, k_in, group_size):
    """Launch the Triton forward kernel and return a new ``(rows, n_out)`` float32 tensor.

    ``x_2d`` supplies the row count and device.  The correction strength is
    read from the environment on every call via ``_bigint_corr_strength()``.
    """
    tiles = {"BLOCK_M": 16, "BLOCK_N": 16, "BLOCK_K": 32}
    rows = x_2d.shape[0]
    out = torch.empty((rows, n_out), device=x_2d.device, dtype=torch.float32)
    launch_grid = (
        triton.cdiv(rows, tiles["BLOCK_M"]),
        triton.cdiv(n_out, tiles["BLOCK_N"]),
    )
    groups_per_row = ceil(k_in / group_size)
    _triton_ternary_fwd_kernel[launch_grid](
        x_2d, packed, e, corr_accum, step_counter, out,
        rows, n_out, k_in, groups_per_row, group_size,
        _bigint_corr_strength(),
        **tiles,
    )
    return out
def _triton_ternary_grad_x(grad_2d, packed, e, corr_accum, step_counter, m_rows, n_out, k_in, group_size):
    """Launch the Triton input-gradient kernel and return a new ``(m_rows, k_in)`` float32 tensor.

    The correction strength is read from the environment on every call via
    ``_bigint_corr_strength()``.
    """
    tiles = {"BLOCK_M": 16, "BLOCK_N": 16, "BLOCK_K": 32}
    out = torch.empty((m_rows, k_in), device=grad_2d.device, dtype=torch.float32)
    launch_grid = (
        triton.cdiv(m_rows, tiles["BLOCK_M"]),
        triton.cdiv(k_in, tiles["BLOCK_K"]),
    )
    groups_per_row = ceil(k_in / group_size)
    _triton_ternary_grad_x_kernel[launch_grid](
        grad_2d, packed, e, corr_accum, step_counter, out,
        m_rows, n_out, k_in, groups_per_row, group_size,
        _bigint_corr_strength(),
        **tiles,
    )
    return out
def _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in):
    """Return a new int8 ``(n_out, k_in)`` tensor holding ``sign(grad_2d^T @ x_2d)``.

    The reduction length is taken from ``x_2d.shape[0]``.
    """
    tiles = {"BLOCK_M": 32, "BLOCK_N": 16, "BLOCK_K": 32}
    signs = torch.empty((n_out, k_in), device=grad_2d.device, dtype=torch.int8)
    launch_grid = (
        triton.cdiv(n_out, tiles["BLOCK_N"]),
        triton.cdiv(k_in, tiles["BLOCK_K"]),
    )
    _triton_ternary_grad_sign_kernel[launch_grid](
        grad_2d, x_2d, signs,
        x_2d.shape[0], n_out, k_in,
        **tiles,
    )
    return signs
def _triton_update_e(packed, grad_sign, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
    """Update ``e`` and ``e_accum`` in place from a precomputed gradient-sign matrix.

    ``BLOCK_K`` is the smallest power of two that is >= ``group_size`` so a
    whole group fits in one lane vector.  Returns ``None``.
    """
    rows_per_prog, groups_per_prog = 8, 4
    groups_per_row = ceil(k_in / group_size)
    lanes = 1 << (group_size - 1).bit_length()
    launch_grid = (
        triton.cdiv(n_out, rows_per_prog),
        triton.cdiv(groups_per_row, groups_per_prog),
    )
    _triton_update_e_kernel[launch_grid](
        packed, grad_sign, e, e_accum,
        n_out, k_in, group_size, groups_per_row, int(e_accum_threshold),
        BLOCK_N=rows_per_prog, BLOCK_G=groups_per_prog, BLOCK_K=lanes,
    )
def _triton_update_e_direct(packed, grad_2d, x_2d, e, e_accum, n_out, k_in, group_size, e_accum_threshold=4):
    """Update ``e`` and ``e_accum`` in place, deriving gradient signs from ``grad_2d`` and ``x_2d``.

    One program is launched per (row block, group).  ``BLOCK_K`` is the
    smallest power of two that is >= ``group_size``.  Returns ``None``.
    """
    rows_in_tile, rows_per_prog = 32, 8
    lanes = 1 << (group_size - 1).bit_length()
    groups_per_row = ceil(k_in / group_size)
    launch_grid = (triton.cdiv(n_out, rows_per_prog), groups_per_row)
    _triton_update_e_direct_kernel[launch_grid](
        packed, grad_2d, x_2d, e, e_accum,
        x_2d.shape[0], n_out, k_in, group_size, groups_per_row, int(e_accum_threshold),
        BLOCK_M=rows_in_tile, BLOCK_N=rows_per_prog, BLOCK_K=lanes,
    )
def _triton_ternary_step(packed, grad_sign, accum, total, accum_threshold, t_accum_step=1,
                         per_group_threshold=None, n_out=0, k_in=0, group_size=0):
    """Apply one accumulate-and-flip step to ``packed`` and ``accum`` in place.

    One program is launched per packed byte (five trits).  When
    ``per_group_threshold`` is ``None`` a one-element placeholder tensor is
    passed and the kernel uses the scalar ``accum_threshold``; the layout
    arguments are then sent as 0.  ``n_out`` is accepted but not used here.
    Returns ``None``.
    """
    if per_group_threshold is None:
        use_group_thresholds = False
        # Placeholder pointer; never dereferenced when the flag is off.
        threshold_arg = torch.empty(1, device=accum.device, dtype=torch.int8)
        k_arg = groups_per_row = group_arg = 0
    else:
        use_group_thresholds = True
        threshold_arg = per_group_threshold
        k_arg, group_arg = k_in, group_size
        groups_per_row = (k_in + group_size - 1) // group_size
    launch_grid = (triton.cdiv(total, 5),)
    _triton_ternary_step_kernel[launch_grid](
        packed, grad_sign, accum,
        threshold_arg,
        total, accum_threshold, int(t_accum_step),
        k_arg, groups_per_row, group_arg,
        use_group_thresholds,
        BLOCK_T=8,
    )
def _triton_ternary_step_direct(packed, grad_2d, x_2d, accum, n_out, k_in, total, accum_threshold, t_accum_step=1,
                                per_group_threshold=None, group_size=0):
    """Apply one accumulate-and-flip step in place, deriving gradient signs from ``grad_2d`` and ``x_2d``.

    One program is launched per packed byte (five trits).  ``k_in`` is always
    forwarded because the kernel needs it to locate each trit's row and
    column; the group layout is only sent when ``per_group_threshold`` is
    given.  Returns ``None``.
    """
    if per_group_threshold is None:
        use_group_thresholds = False
        # Placeholder pointer; never dereferenced when the flag is off.
        threshold_arg = torch.empty(1, device=accum.device, dtype=torch.int8)
        groups_per_row = group_arg = 0
    else:
        use_group_thresholds = True
        threshold_arg = per_group_threshold
        group_arg = group_size
        groups_per_row = (k_in + group_size - 1) // group_size
    launch_grid = (triton.cdiv(total, 5),)
    _triton_ternary_step_direct_kernel[launch_grid](
        packed, grad_2d, x_2d, accum,
        threshold_arg,
        x_2d.shape[0], n_out, k_in,
        total, accum_threshold, int(t_accum_step),
        groups_per_row, group_arg,
        use_group_thresholds,
        BLOCK_M=32, BLOCK_T=8,
    )
def _triton_accumulate_direct(packed, grad_2d, x_2d, t_accum, e_accum,
                              n_out, k_in, group_size,
                              t_accum_step=1, e_accum_step=1,
                              update_scales=True):
    """Launch the direct T-accumulation kernel and, optionally, the E one.

    ``_triton_accumulate_t_direct_kernel`` always runs, with one program per 5
    of the ``n_out * k_in`` weights.  ``_triton_accumulate_e_direct_kernel``
    runs only when ``update_scales`` is truthy and ``e_accum`` is not ``None``,
    on a ``(cdiv(n_out, 8), groups_per_row)`` grid.
    """
    batch_tile = 32
    batch = grad_2d.shape[0]
    n_weights = n_out * k_in
    _triton_accumulate_t_direct_kernel[(triton.cdiv(n_weights, 5),)](
        grad_2d, x_2d, t_accum,
        batch, n_out, k_in, n_weights, int(t_accum_step),
        BLOCK_M=batch_tile, BLOCK_T=8,
    )
    if not update_scales or e_accum is None:
        return

    row_tile = 8
    # Smallest power of two that is >= group_size.
    k_tile = 1 << (group_size - 1).bit_length()
    groups_per_row = ceil(k_in / group_size)
    scale_grid = (triton.cdiv(n_out, row_tile), groups_per_row)
    _triton_accumulate_e_direct_kernel[scale_grid](
        packed, grad_2d, x_2d, e_accum,
        batch, n_out, k_in, group_size, groups_per_row, int(e_accum_step),
        BLOCK_M=batch_tile, BLOCK_N=row_tile, BLOCK_K=k_tile,
    )
def _triton_accumulate_corr_direct(packed, grad_2d, x_2d, corr_accum,
                                   n_out, k_in, group_size, corr_step=1):
    """Launch ``_triton_accumulate_corr_direct_kernel`` on an (output-row tile, group) grid.

    Axis 0 has ``cdiv(n_out, 8)`` programs and axis 1 has one program per scale
    group of a row.  ``corr_step`` is coerced to ``int`` before being passed;
    the batch size passed is ``grad_2d.shape[0]``.
    """
    batch_tile = 32
    row_tile = 8
    # Smallest power of two that is >= group_size.
    k_tile = 1 << (group_size - 1).bit_length()
    groups_per_row = ceil(k_in / group_size)
    launch_grid = (triton.cdiv(n_out, row_tile), groups_per_row)
    _triton_accumulate_corr_direct_kernel[launch_grid](
        packed, grad_2d, x_2d, corr_accum,
        grad_2d.shape[0], n_out, k_in, group_size, groups_per_row, int(corr_step),
        BLOCK_M=batch_tile, BLOCK_N=row_tile, BLOCK_K=k_tile,
    )
def _triton_apply_accumulated_flips(packed, accum, total, accum_threshold,
                                    per_group_threshold=None,
                                    k_in=0, group_size=0):
    """Launch ``_triton_apply_accumulated_flips_kernel`` with one program per 5 weights.

    Without ``per_group_threshold`` the kernel receives a 1-element int8
    placeholder tensor, zeros for K / groups-per-row / group-size, and
    ``False`` for the per-group flag.
    """
    use_group_thresholds = per_group_threshold is not None
    if use_group_thresholds:
        threshold_arg = per_group_threshold
        k_arg, group_arg = k_in, group_size
        groups_per_row = (k_in + group_size - 1) // group_size
    else:
        # The kernel still needs a valid pointer for the threshold argument.
        threshold_arg = torch.empty(1, device=accum.device, dtype=torch.int8)
        k_arg = group_arg = groups_per_row = 0
    launch_grid = (triton.cdiv(total, 5),)
    _triton_apply_accumulated_flips_kernel[launch_grid](
        packed, accum,
        threshold_arg,
        total, accum_threshold,
        k_arg, groups_per_row, group_arg,
        use_group_thresholds,
        BLOCK_T=8,
    )
@triton.jit
def _triton_ternary_embed_fwd_kernel(
    idx_ptr, packed_ptr, e_ptr, out_ptr,
    NUM_IDX: tl.constexpr, DIM: tl.constexpr,
    VOCAB: tl.constexpr, GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
    BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Gather rows of the packed ternary table and scale them by 2**E.

    Each program handles BLOCK_B lookups across BLOCK_D columns.  For element
    (token, d) the flat weight index is ``token * DIM + d``; its trit is the
    base-3 digit ``index % 5`` of packed element ``index // 5``, the sign is
    ``trit - 1`` and the stored value is ``sign * 2**E[token * GPR + d // GROUP_SIZE]``.
    """
    row = tl.program_id(0) * BLOCK_B + tl.arange(0, BLOCK_B)
    col = tl.arange(0, BLOCK_D)
    row_ok = row < NUM_IDX
    in_bounds = row_ok[:, None] & (col[None, :] < DIM)

    token = tl.load(idx_ptr + row, mask=row_ok, other=0).to(tl.int32)
    flat = token[:, None] * DIM + col[None, :]
    byte_idx = flat // 5
    slot = flat - byte_idx * 5
    byte_val = tl.load(packed_ptr + byte_idx, mask=in_bounds, other=0).to(tl.int32)

    # 3**slot, selected without a runtime pow.
    pow3 = tl.where(
        slot == 0, 1,
        tl.where(slot == 1, 3,
                 tl.where(slot == 2, 9,
                          tl.where(slot == 3, 27, 81))),
    )
    trit = (byte_val // pow3) % 3
    sign = trit.to(tl.int32) - 1

    scale_idx = token[:, None] * GPR + col[None, :] // GROUP_SIZE
    exponent = tl.load(e_ptr + scale_idx, mask=in_bounds, other=0).to(tl.float32)
    weight = sign.to(tl.float32) * tl.exp2(exponent)
    weight = tl.where(in_bounds, weight, 0.0)
    tl.store(
        out_ptr + row[:, None] * DIM + col[None, :],
        weight,
        mask=in_bounds,
    )
@triton.jit
def _triton_ternary_embed_bwd_accum_kernel(
    idx_ptr, grad_ptr, accum_ptr,
    NUM_IDX: tl.constexpr, DIM: tl.constexpr,
    BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Scatter-add output-gradient rows into ``accum`` at their lookup index.

    Row ``b`` of ``grad`` is atomically added to row ``idx[b]`` of ``accum``
    (both addressed with a row stride of DIM).  Atomics are used because several
    lookups may target the same row.
    """
    row = tl.program_id(0) * BLOCK_B + tl.arange(0, BLOCK_B)
    col = tl.arange(0, BLOCK_D)
    row_ok = row < NUM_IDX
    in_bounds = row_ok[:, None] & (col[None, :] < DIM)
    token = tl.load(idx_ptr + row, mask=row_ok, other=0).to(tl.int32)
    grad_tile = tl.load(grad_ptr + row[:, None] * DIM + col[None, :], mask=in_bounds, other=0.0)
    target = token[:, None] * DIM + col[None, :]
    tl.atomic_add(accum_ptr + target, grad_tile, mask=in_bounds)
@triton.jit
def _triton_ternary_embed_bwd_sign_kernel(
    accum_ptr, sign_ptr,
    VOCAB: tl.constexpr, DIM: tl.constexpr,
    BLOCK_V: tl.constexpr, BLOCK_D: tl.constexpr,
):
    """Write the elementwise sign (-1, 0 or +1, as int8) of ``accum`` to ``sign``.

    Each program covers BLOCK_V rows of the (VOCAB, DIM) table.
    """
    row = tl.program_id(0) * BLOCK_V + tl.arange(0, BLOCK_V)
    col = tl.arange(0, BLOCK_D)
    in_bounds = (row[:, None] < VOCAB) & (col[None, :] < DIM)
    offsets = row[:, None] * DIM + col[None, :]
    total = tl.load(accum_ptr + offsets, mask=in_bounds, other=0.0)
    direction = tl.where(total > 0.0, 1, tl.where(total < 0.0, -1, 0)).to(tl.int8)
    tl.store(sign_ptr + offsets, direction, mask=in_bounds)
def _triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim):
    """Return an int8 ``(vocab, dim)`` tensor holding the sign of the summed row gradients.

    Runs two kernels: the first scatter-adds ``grad_output`` (viewed as
    ``(-1, dim)``) into a zero-initialised float32 ``(vocab, dim)`` buffer keyed
    by ``indices``; the second writes the elementwise sign of that buffer.
    """
    device = grad_output.device
    col_tile = triton.next_power_of_2(dim)

    lookup = indices.reshape(-1).contiguous().to(torch.int32)
    grad_rows = grad_output.reshape(-1, dim).contiguous()
    n_lookups = lookup.shape[0]
    summed = torch.zeros(vocab, dim, device=device, dtype=torch.float32)
    lookup_tile = 64
    _triton_ternary_embed_bwd_accum_kernel[(triton.cdiv(n_lookups, lookup_tile),)](
        lookup, grad_rows, summed,
        n_lookups, dim,
        BLOCK_B=lookup_tile, BLOCK_D=col_tile,
    )

    signs = torch.empty(vocab, dim, device=device, dtype=torch.int8)
    vocab_tile = 32
    _triton_ternary_embed_bwd_sign_kernel[(triton.cdiv(vocab, vocab_tile),)](
        summed, signs,
        vocab, dim,
        BLOCK_V=vocab_tile, BLOCK_D=col_tile,
    )
    return signs
def _triton_ternary_embed(indices, packed, e, vocab, dim, group_size):
    """Look up ``indices`` in the packed ternary table via the Triton forward kernel.

    Returns a float32 tensor of shape ``(*indices.shape, dim)`` allocated on
    ``indices.device``.
    """
    lookup = indices.reshape(-1).contiguous().to(torch.int32)
    n_lookups = lookup.shape[0]
    rows = torch.empty((n_lookups, dim), device=indices.device, dtype=torch.float32)
    lookup_tile = 32
    groups_per_row = ceil(dim / group_size)
    _triton_ternary_embed_fwd_kernel[(triton.cdiv(n_lookups, lookup_tile),)](
        lookup, packed, e, rows,
        n_lookups, dim, vocab, groups_per_row, group_size,
        BLOCK_B=lookup_tile, BLOCK_D=triton.next_power_of_2(dim),
    )
    return rows.reshape(*indices.shape, dim)
class _TritonTernaryEmbedFn(torch.autograd.Function):
    """Autograd wrapper around the Triton ternary embedding lookup.

    No gradient is returned for any input.  ``backward`` instead hands the
    sign of the per-row summed gradient to the owning module, either by
    streaming it into ``module._accumulate_corr_from_grad_sign`` or by storing
    it (together with the unpacked ternary matrix) on ``_hook_*`` attributes.
    """

    @staticmethod
    def forward(ctx, indices, _dummy, module):
        """Run the lookup and stash what ``backward`` needs on ``ctx``.

        ``_dummy`` is not read here.  The component name active in
        ``_COMPONENT_CONTEXT`` at forward time is captured for ``backward``.
        """
        shape = tuple(module._T_shape.tolist())
        vocab, dim = shape
        packed = module.T_packed.contiguous()
        e = module.E.contiguous()
        ctx.save_for_backward(indices, packed, e)
        ctx.module = module
        ctx.shape = shape
        ctx.group_size = module.group_size
        comp_name, _ = _COMPONENT_CONTEXT.get()
        ctx.comp_name = comp_name
        return _triton_ternary_embed(indices, packed, e, vocab, dim, module.group_size)

    @staticmethod
    def backward(ctx, grad_output):
        """Publish the gradient sign to the module; returns ``(None, None, None)``."""
        indices, packed, _e = ctx.saved_tensors
        module = ctx.module
        vocab, dim = ctx.shape
        grad_2d = grad_output.reshape(-1, dim).contiguous()
        comp_name = ctx.comp_name
        # Every branch needs the same sign tensor, so compute it once.
        grad_sign = _triton_ternary_embed_grad_sign(indices, grad_2d, vocab, dim)

        has_corr = hasattr(module, "corr_accum") and hasattr(module, "_accumulate_corr_from_grad_sign")
        if getattr(module, "_stream_backward_updates", True) and has_corr:
            # BigInt streaming: accumulate correlation directly.  The module
            # method receives only the sign tensor, so nothing is unpacked here.
            module._accumulate_corr_from_grad_sign(grad_sign)
            module._streamed_bigint_backward = True
            return None, None, None

        # Deferred path: leave the sign and the unpacked ternary matrix on the
        # module under a per-component name (or the default name).
        T = unpack_ternary(packed, tuple(module._T_shape.tolist()), int(module._T_pad.item()))
        T = T.to(device=grad_2d.device)
        if comp_name is not None:
            setattr(module, f"_hook_grad_T_sign_{comp_name}", grad_sign)
            setattr(module, f"_hook_T_{comp_name}", T)
        else:
            module._hook_grad_T_sign = grad_sign
            module._hook_T = T
        return None, None, None
class _TritonTernaryLinearFn(torch.autograd.Function):
    """Autograd wrapper around the Triton ternary linear forward / grad-x kernels.

    ``backward`` returns the input gradient and, as a side effect, either
    streams a correlation update into ``module.corr_accum`` or stores the
    weight-gradient sign on a ``_hook_grad_T_sign*`` attribute of the module.
    """

    @staticmethod
    def forward(ctx, x, module):
        """Compute the linear output for ``x`` flattened to ``(-1, k_in)``.

        The value of ``module.step_counter`` at forward time is snapshotted so
        that ``backward`` can pass the same step to the grad-x kernel.
        """
        weight_shape = tuple(module._T_shape.tolist())
        n_out, k_in = weight_shape
        x_2d = x.reshape(-1, k_in).contiguous()
        packed = module.T_packed.contiguous()
        e = module.E.contiguous()
        ctx.save_for_backward(x_2d, packed, e)
        ctx.step_snapshot = int(module.step_counter.item())
        ctx.module = module
        ctx.shape = weight_shape
        ctx.x_shape = x.shape
        ctx.group_size = module.group_size
        ctx.comp_name = _COMPONENT_CONTEXT.get()[0]
        out = _triton_ternary_forward(
            x_2d, packed, e,
            module.corr_accum.contiguous(), module.step_counter.contiguous(),
            n_out, k_in, module.group_size,
        )
        return out.reshape(*x.shape[:-1], n_out)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, None)`` and push the weight update signal to the module."""
        x_2d, packed, e = ctx.saved_tensors
        module = ctx.module
        n_out, k_in = ctx.shape
        grad_2d = grad_output.reshape(-1, n_out).contiguous()
        snapshot = torch.tensor([ctx.step_snapshot], device=e.device, dtype=torch.int64)
        grad_x = _triton_ternary_grad_x(
            grad_2d, packed, e, module.corr_accum.contiguous(), snapshot,
            x_2d.shape[0], n_out, k_in, ctx.group_size,
        )
        with torch.no_grad():
            if getattr(module, "_stream_backward_updates", True):
                # The weight is read from the context active during backward.
                bwd_weight = _COMPONENT_CONTEXT.get()[1]
                # Integer step of magnitude >= 1 carrying the weight's sign.
                magnitude = max(1, int(round(abs(float(bwd_weight)))))
                corr_step = -magnitude if bwd_weight < 0 else magnitude
                _triton_accumulate_corr_direct(
                    packed, grad_2d, x_2d, module.corr_accum,
                    n_out, k_in, ctx.group_size, corr_step=corr_step,
                )
                module.step_counter.add_(magnitude)
                module._streamed_bigint_backward = True
            else:
                grad_sign = _triton_ternary_grad_sign(grad_2d, x_2d, n_out, k_in).detach()
                if ctx.comp_name is None:
                    hook_name = "_hook_grad_T_sign"
                else:
                    hook_name = f"_hook_grad_T_sign_{ctx.comp_name}"
                setattr(module, hook_name, grad_sign)
        return grad_x.reshape(*ctx.x_shape), None
class _BigIntTernaryLinearFn(torch.autograd.Function):
    """Dense PyTorch fallback for the ternary linear layer.

    Forward multiplies by ``module.dequantize()`` in float32 (adding
    ``module.bias`` when present).  Backward returns the input gradient and
    streams ``sign(grad^T @ x)`` into ``module._accumulate_corr_from_grad_sign``.
    """

    @staticmethod
    def forward(ctx, x, module):
        """Return ``x @ W_eff^T (+ bias)`` as float32, shaped ``(*x.shape[:-1], n_out)``."""
        shape = tuple(module._T_shape.tolist())
        n_out, k_in = shape
        x_2d = x.reshape(-1, k_in).contiguous()
        ctx.module = module
        ctx.x_shape = x.shape
        ctx.shape = shape
        ctx.x_dtype = x.dtype
        ctx.save_for_backward(x_2d)
        with torch.no_grad():
            w_eff = module.dequantize().to(device=x.device, dtype=torch.float32)
        bias = module.bias
        if bias is not None:
            # Move the bias alongside the weight so that a module living on a
            # different device than the input does not fail inside F.linear.
            bias = bias.to(device=x.device, dtype=torch.float32)
        out = F.linear(x_2d.float(), w_eff, bias)
        return out.reshape(*x.shape[:-1], n_out)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, None)``; ``grad_x`` is cast back to the input dtype."""
        (x_2d,) = ctx.saved_tensors
        module = ctx.module
        n_out, k_in = ctx.shape
        grad_2d = grad_output.reshape(-1, n_out).contiguous()
        with torch.no_grad():
            # The weight is dequantized again here rather than saved in forward.
            w_eff = module.dequantize().to(device=grad_2d.device, dtype=torch.float32)
            grad_x = grad_2d.float() @ w_eff
            grad_sign = (grad_2d.float().transpose(0, 1) @ x_2d.float()).sign().to(torch.int8)
            module._accumulate_corr_from_grad_sign(grad_sign)
            module._streamed_bigint_backward = True
        return grad_x.reshape(*ctx.x_shape).to(dtype=ctx.x_dtype), None
"""
Log-Space Group Scale Representation
Convention (matching agents' Option B recommendation):
S = 2^E where S = scale, E = int8 log-space exponent
W_eff = T * 2^E
Key log-space properties exploited:
Multiplication → addition: S1 * S2 = 2^(E1 + E2)
Division → subtraction: S1 / S2 = 2^(E1 - E2)
Dequant → integer shift: 2^E * T = T << E (for E >= 0)
No IEEE floats in persistent state. E is stored as int8.
Ephemeral float only exists in autograd's computation graph.
"""
class TScaleType(IntEnum):
    """Supported scale-group layouts.

    Each member's integer value equals the group size that ``GROUP_SIZES``
    maps it to.
    """

    T4 = 4
    T6 = 6
    T8 = 8
    T16 = 16
    T32 = 32
    T64 = 64
    T96 = 96
# Group size for each TScaleType; every member's integer value is its group size.
GROUP_SIZES = {scale_type: int(scale_type) for scale_type in TScaleType}

TILE_SIZE = 384
def _n_groups(shape, group_size):
    """Return the total number of scale groups for a 2-D weight of ``shape``.

    Every row is split into ``ceil(in_dim / group_size)`` groups, so a trailing
    partial group counts as a full one.
    """
    rows, cols = shape
    groups_per_row = ceil(cols / group_size)
    return rows * groups_per_row
def _expand_E(E, shape, group_size):
    """Broadcast flat per-group values ``E`` to one value per weight column.

    ``E`` is viewed as ``(out_dim, groups_per_row)`` and each entry is repeated
    ``group_size`` times along the row; columns beyond ``in_dim`` (from a
    partial last group) are cut off.  Returns an ``(out_dim, in_dim)`` tensor.
    """
    rows, cols = shape
    groups_per_row = ceil(cols / group_size)
    per_column = E.view(rows, groups_per_row).repeat_interleave(group_size, dim=1)
    if per_column.shape[1] <= cols:
        return per_column
    return per_column[:, :cols]
def _ternarize(x, threshold=0.05):
    """Map ``x`` elementwise to ``sign(x)`` where ``|x| > threshold`` and to zero elsewhere.

    The result keeps the dtype of ``x``.
    """
    keep = (x.abs() > threshold).to(x.dtype)
    return x.sign() * keep
def _scaled_init_threshold(threshold: float, init_std: float) -> float:
    """Cap ``threshold`` at half of ``init_std``.

    A non-positive ``init_std`` disables the cap and ``threshold`` is returned
    unchanged (not converted to ``float``).
    """
    return threshold if init_std <= 0 else min(float(threshold), 0.5 * float(init_std))
class TernaryScaleTensor(nn.Module):
    """Linear layer whose weight is stored as packed ternary signs plus int8 group exponents.

    Persistent state (all buffers, no parameters):
      * ``T_packed`` / ``_T_shape`` / ``_T_pad`` - packed ternary matrix as
        produced by ``pack_ternary``.
      * ``E`` - flat int8 tensor with one log2 exponent per scale group.
      * ``corr_accum`` / ``step_counter`` - int64 correlation accumulator per
        group and the number of accumulated steps.
      * ``bias`` - optional int32 vector of length ``out_dim``.

    The effective weight is ``T * 2**(E + correction)``, see ``_get_S``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        threshold: float = 0.05,
        weight_init_std: float | None = None,
        tscale_type: TScaleType = TScaleType.T32,
        bias: bool = False,
    ):
        """Draw a random Gaussian weight, ternarize it and derive the initial exponents.

        Args:
            in_dim: Number of input features (weight columns).
            out_dim: Number of output features (weight rows).
            threshold: Ternarization threshold; capped by
                ``_scaled_init_threshold`` using the init standard deviation.
            weight_init_std: Standard deviation of the random init; defaults to
                ``min(0.1, in_dim ** -0.5)``.
            tscale_type: Selects the scale group size via ``GROUP_SIZES``.
            bias: Whether to register a zero int32 bias buffer.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        init_std = min(0.1, in_dim ** -0.5) if weight_init_std is None else float(weight_init_std)
        init_threshold = _scaled_init_threshold(threshold, init_std)
        self.threshold = init_threshold
        self.tscale_type = tscale_type
        self.group_size = GROUP_SIZES[tscale_type]
        w_init = torch.randn(out_dim, in_dim) * init_std
        T_init = _ternarize(w_init, init_threshold)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))
        self.register_buffer("E", self._group_log2_E(w_init.abs(), out_dim, in_dim, self.group_size))
        self.register_buffer("corr_accum", torch.zeros_like(self.E, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
        if bias:
            self.register_buffer("bias", torch.zeros(out_dim, dtype=torch.int32))
        else:
            self.bias = None

    @staticmethod
    def _group_log2_E(abs_w, out_dim, in_dim, group_size, device=None):
        """Return flat int8 exponents ``log2(mean |w|)`` for every scale group.

        Rows are zero-padded to a whole number of groups, so the padding is
        included in the mean of a partial last group.  Groups whose mean is not
        positive get exponent 0.  The float-to-int8 cast truncates toward zero.
        """
        gpr = ceil(in_dim / group_size)
        padded = torch.zeros(out_dim, gpr * group_size, device=device)
        padded[:, :in_dim] = abs_w
        grp_means = padded.view(out_dim, gpr, group_size).mean(dim=2)
        E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
        return E_vals.log2().clamp(-128, 127).to(torch.int8).flatten()

    def _get_T(self):
        """Unpack and return the ternary matrix via ``unpack_ternary``."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_S(self):
        """Return the per-element float scale ``2**(E + correction)``.

        Once ``step_counter`` is positive the correction is
        ``corr_accum / (step_counter * group_size)`` multiplied by the
        ``ARB_BIGINT_CORR_STRENGTH`` setting.
        """
        e_adj = self.E.float()
        if hasattr(self, "corr_accum") and hasattr(self, "step_counter"):
            step = int(self.step_counter.item())
            if step > 0:
                denom = max(step * self.group_size, 1)
                e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
        E_exp = _expand_E(e_adj, (self.out_dim, self.in_dim), self.group_size)
        return torch.exp2(E_exp)

    def _ensure_group_lr(self):
        """Return ``group_lr``, (re)creating it as int8 ones when missing or stale.

        "Stale" means its shape or device no longer matches ``E``.
        """
        if not hasattr(self, "group_lr"):
            self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
        elif self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device:
            self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
        return self.group_lr

    def precompile_kernels(self, M: int):
        """No-op placeholder; ``M`` is ignored."""
        pass

    def forward(self, x):
        """Apply the layer using the backend chosen by ``ARB_TERNARY_BACKEND``.

        Order of preference: TileLang (only when explicitly requested and
        importable), Triton (CUDA input and backend ``auto``/``triton``), then
        the dense PyTorch fallback.  A TileLang failure emits a warning and
        falls through to Triton when available, otherwise to PyTorch.

        Raises:
            RuntimeError: if TileLang is requested while grad is enabled and
                ``ARB_TILELANG_TRAINING`` is not set, or if Triton is requested
                but cannot be used for this input.
        """
        backend = _backend_preference()

        # All backends are custom autograd Functions.  When grad is enabled the
        # input is made to require grad - presumably so their backward (which
        # carries the weight-update side effects) runs even for inputs that do
        # not need a gradient themselves.
        x_for_grad = x
        if torch.is_grad_enabled() and not x.requires_grad:
            x_for_grad = x.detach().requires_grad_(True)

        if backend == "tilelang" and _HAS_TILELANG:
            if torch.is_grad_enabled() and not _tilelang_training_enabled():
                raise RuntimeError(
                    "ARB_TERNARY_BACKEND='tilelang' is inference-only by default. "
                    "BigInt ternary training should use ARB_TERNARY_BACKEND='triton'. "
                    "Set ARB_TILELANG_TRAINING=1 only for experimental TileLang training."
                )
            N, K = tuple(self._T_shape.tolist())
            M = x_for_grad.reshape(-1, K).shape[0]
            try:
                fwd_kernel = _get_kernel(M, N, K, self.group_size, "fwd")
                y = _TernaryLinearFn.apply(x_for_grad, self, fwd_kernel)
                if self.bias is not None:
                    y = y + self.bias.float()
                return y
            except Exception as e:
                # Deliberate best-effort: report the failure and fall back.
                warnings.warn(
                    f"TileLang forward failed for {self._T_shape.tolist()}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                backend = "triton" if _HAS_TRITON else "torch"

        if x.is_cuda and _HAS_TRITON and backend in {"auto", "triton"}:
            y = _TritonTernaryLinearFn.apply(x_for_grad, self)
            if self.bias is not None:
                y = y + self.bias.float()
            return y
        if backend == "triton":
            raise RuntimeError("ARB_TERNARY_BACKEND='triton' requested, but Triton is unavailable for this input.")
        # The dense fallback adds the bias itself.
        return _BigIntTernaryLinearFn.apply(x_for_grad, self)

    @torch.no_grad()
    def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
        """Fold a weight-gradient sign matrix into ``corr_accum``.

        For every scale group the sum of ``grad_sign * T`` over the group is
        multiplied by ``corr_step`` and subtracted from ``corr_accum``;
        ``step_counter`` grows by ``|corr_step|``.  A ``grad_sign`` whose shape
        differs from the weight shape is ignored silently.
        """
        shape = tuple(self._T_shape.tolist())
        out_dim, in_dim = shape
        if tuple(grad_sign.shape) != shape:
            return
        T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
        signed = grad_sign.to(torch.int16) * T
        gpr = ceil(in_dim / self.group_size)
        total_in = gpr * self.group_size
        if total_in > in_dim:
            # Zero-pad so every row splits into whole groups.
            signed = F.pad(signed, (0, total_in - in_dim))
        score = signed.view(out_dim, gpr, self.group_size).sum(dim=2, dtype=torch.int16)
        self.corr_accum -= score.flatten().to(device=self.corr_accum.device, dtype=torch.int64) * int(corr_step)
        self.step_counter += abs(int(corr_step))

    def ternary_step(self, lr=1, accum_threshold=None):
        """Consume a pending ``_hook_grad_T_sign`` into the correlation accumulator.

        ``lr`` and ``accum_threshold`` are accepted but not used.  Always resets
        ``_had_flip`` to ``False``.
        """
        self._had_flip = False
        if hasattr(self, "_hook_grad_T_sign"):
            self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign

    def update_E(self, lr=1, loss_signal=None):
        """Consume pending backward hooks into the correlation accumulator.

        Uses ``_hook_grad_T_sign`` when present; otherwise rebuilds the sign
        from ``_hook_grad_2d`` / ``_hook_x_2d``.  Does nothing when neither is
        available.  ``lr`` and ``loss_signal`` are accepted but not used.
        """
        has_dense_grad = hasattr(self, "_hook_grad_T_sign")
        has_direct_grad = hasattr(self, "_hook_grad_2d") and hasattr(self, "_hook_x_2d")
        if not has_dense_grad and not has_direct_grad:
            return
        if has_dense_grad:
            self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign
        else:
            grad = self._hook_grad_2d.to(device=self.E.device, dtype=torch.float32)
            x = self._hook_x_2d.to(device=self.E.device, dtype=torch.float32)
            grad_sign = (grad.transpose(0, 1) @ x).sign().to(torch.int8)
            self._accumulate_corr_from_grad_sign(grad_sign)
            del self._hook_grad_2d
            del self._hook_x_2d
        if hasattr(self, "_hook_T"):
            del self._hook_T

    @property
    def effective_bpw(self) -> float:
        """Stored bits per weight.

        Counts 8/5 bits per ternary value, 8 bits of ``E`` and 64 bits of
        ``corr_accum`` per group, and 32 bits per bias element.
        """
        group_size = self.group_size
        total = self._T_shape[0].item() * self._T_shape[1].item()
        n_grp = _n_groups(tuple(self._T_shape.tolist()), group_size)
        sign_bits = total * (8 / 5)
        scale_bits = n_grp * 8.0
        corr_bits = n_grp * 64.0
        bias_bits = self.bias.numel() * 32.0 if self.bias is not None else 0.0
        return (sign_bits + scale_bits + corr_bits + bias_bits) / total

    def dequantize(self) -> torch.Tensor:
        """Return the dense float weight ``S * T``."""
        T = self._get_T().float()
        S = self._get_S()
        return S * T

    def tscale_to(self, tscale_type: TScaleType):
        """Switch to another group size in place and return ``self``.

        Whenever the group layout changes, ``E`` is re-derived from the mean
        ``|T|`` of each new group and ``corr_accum`` / ``step_counter`` are
        reset to zero.  The layout counts as changed when the group size
        differs, even if the number of groups happens to stay the same, because
        the group boundaries move and the old per-group values would no longer
        line up with the columns they scale.
        """
        old_group_size = self.group_size
        self.tscale_type = tscale_type
        self.group_size = GROUP_SIZES[tscale_type]
        out_dim, in_dim = tuple(self._T_shape.tolist())
        new_n_grp = out_dim * ceil(in_dim / self.group_size)
        if self.group_size != old_group_size or self.E.shape[0] != new_n_grp:
            abs_T = self._get_T().float().abs()
            self.E = self._group_log2_E(
                abs_T, out_dim, in_dim, self.group_size, device=self.T_packed.device
            )
            self.corr_accum = torch.zeros_like(self.E, dtype=torch.int64)
            self.step_counter = torch.zeros(1, dtype=torch.int64, device=self.E.device)
        return self

    tscale_cast = tscale_to

    def extra_repr(self) -> str:
        """Summarize dimensions, scale type and storage cost for ``repr``."""
        return (
            f"in_dim={self.in_dim}, out_dim={self.out_dim}, "
            f"tscale_type={self.tscale_type.name}, group_size={self.group_size}, "
            f"effective_bpw={self.effective_bpw:.2f}"
        )
if _HAS_TRITON:

    @triton.jit
    def _triton_rmsnorm_fwd_kernel(
        x_ptr, packed_ptr, e_ptr, out_ptr,
        BATCH: tl.constexpr, DIM: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
        EPS: tl.constexpr = 1e-5,
    ):
        """RMS-normalize BLOCK_B rows of ``x`` and scale them by the ternary weight.

        Computes ``out = x / sqrt(mean(x**2) + EPS) * w`` where column ``d`` of
        ``w`` is ``(trit_d - 1) * 2**E[d // GROUP_SIZE]`` and ``trit_d`` is the
        base-3 digit ``d % 5`` of packed element ``d // 5``.  ``EPS`` defaults
        to the value that used to be hard-coded, so callers that do not pass it
        get the previous behaviour.  ``GPR`` is not read.
        """
        pid_b = tl.program_id(0)
        offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
        offs_d = tl.arange(0, BLOCK_D)
        col_ok = offs_d < DIM
        valid = (offs_b[:, None] < BATCH) & col_ok[None, :]
        x = tl.load(
            x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            mask=valid,
            other=0.0,
        )
        # Masked lanes hold 0, so dividing the sum by DIM gives the true mean.
        sq = x * x
        msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
        rms = tl.sqrt(msq + EPS)
        x_norm = x / rms
        pack_idx = offs_d // 5
        trit_pos = offs_d - pack_idx * 5
        packed = tl.load(packed_ptr + pack_idx, mask=col_ok, other=0).to(tl.int32)
        # 3**trit_pos, selected without a runtime pow.
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
                     tl.where(trit_pos == 2, 9,
                              tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        sign = trit.to(tl.int32) - 1
        e_idx = offs_d // GROUP_SIZE
        e_val = tl.load(e_ptr + e_idx, mask=col_ok, other=0).to(tl.float32)
        w = sign.to(tl.float32) * tl.exp2(e_val)
        w = tl.where(col_ok, w, 0.0)
        out = x_norm * w[None, :]
        tl.store(
            out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            out,
            mask=valid,
        )

    @triton.jit
    def _triton_rmsnorm_bwd_kernel(
        grad_out_ptr, x_ptr, packed_ptr, e_ptr,
        grad_x_ptr,
        BATCH: tl.constexpr, DIM: tl.constexpr,
        GPR: tl.constexpr, GROUP_SIZE: tl.constexpr,
        BLOCK_B: tl.constexpr, BLOCK_D: tl.constexpr,
        EPS: tl.constexpr = 1e-5,
    ):
        """Input gradient of the forward kernel for BLOCK_B rows.

        Recomputes ``rms`` and ``x_norm`` from ``x`` and stores
        ``dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms``.  ``EPS`` must
        match the value used in the forward pass.  ``GPR`` is not read.
        """
        pid_b = tl.program_id(0)
        offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
        offs_d = tl.arange(0, BLOCK_D)
        col_ok = offs_d < DIM
        valid = (offs_b[:, None] < BATCH) & col_ok[None, :]
        x = tl.load(
            x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            mask=valid,
            other=0.0,
        )
        sq = x * x
        msq = tl.sum(sq, axis=1, keep_dims=True) / DIM
        rms = tl.sqrt(msq + EPS)
        x_norm = x / rms
        pack_idx = offs_d // 5
        trit_pos = offs_d - pack_idx * 5
        packed = tl.load(packed_ptr + pack_idx, mask=col_ok, other=0).to(tl.int32)
        # 3**trit_pos, selected without a runtime pow.
        divisor = tl.where(
            trit_pos == 0, 1,
            tl.where(trit_pos == 1, 3,
                     tl.where(trit_pos == 2, 9,
                              tl.where(trit_pos == 3, 27, 81))),
        )
        trit = (packed // divisor) % 3
        sign = trit.to(tl.int32) - 1
        e_idx = offs_d // GROUP_SIZE
        e_val = tl.load(e_ptr + e_idx, mask=col_ok, other=0).to(tl.float32)
        w = sign.to(tl.float32) * tl.exp2(e_val)
        w = tl.where(col_ok, w, 0.0)
        dy = tl.load(
            grad_out_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            mask=valid,
            other=0.0,
        )
        dyw = dy * w[None, :]
        c1 = tl.sum(x_norm * dyw, axis=1, keep_dims=True) / DIM
        dx = (dyw - x_norm * c1) / rms
        tl.store(
            grad_x_ptr + offs_b[:, None] * DIM + offs_d[None, :],
            dx,
            mask=valid,
        )

    class _TritonRMSNormFn(torch.autograd.Function):
        """Autograd wrapper around the Triton RMSNorm kernels.

        Only the input receives a gradient; the packed weight and exponents
        are treated as constants.
        """

        @staticmethod
        def forward(ctx, x, module, packed, e, dim, group_size):
            """Normalize ``x`` over its last ``dim`` features; output has the shape of ``x``.

            The epsilon comes from ``module.eps`` so the Triton path agrees
            with the module's PyTorch path; a module without that attribute
            falls back to ``1e-5``.
            """
            ctx.module = module
            eps = float(getattr(module, "eps", 1e-5))
            x_2d = x.reshape(-1, dim).contiguous()
            batch = x_2d.shape[0]
            out = torch.empty_like(x_2d)
            block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _triton_rmsnorm_fwd_kernel[grid](
                x_2d, packed, e, out,
                batch, dim, ceil(dim / group_size), group_size,
                BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
                EPS=eps,
            )
            ctx.save_for_backward(x_2d, packed, e)
            ctx.dim = dim
            ctx.group_size = group_size
            ctx.eps = eps
            comp_name, _ = _COMPONENT_CONTEXT.get()
            ctx.comp_name = comp_name
            return out.reshape(*x.shape)

        @staticmethod
        def backward(ctx, grad_output):
            """Return the input gradient followed by ``None`` for the five other inputs."""
            x_2d, packed, e = ctx.saved_tensors
            dim = ctx.dim
            group_size = ctx.group_size
            grad_2d = grad_output.reshape(-1, dim).contiguous()
            batch = grad_2d.shape[0]
            grad_x = torch.empty_like(x_2d)
            block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _triton_rmsnorm_bwd_kernel[grid](
                grad_2d, x_2d, packed, e, grad_x,
                batch, dim, ceil(dim / group_size), group_size,
                BLOCK_B=block_b, BLOCK_D=triton.next_power_of_2(dim),
                EPS=ctx.eps,
            )
            return grad_x.reshape(*grad_output.shape), None, None, None, None, None
class TernaryRMSNorm(nn.Module):
    """RMS normalization with a ternary, group-scaled weight vector.

    The weight starts from a row of ones and is stored as a packed ternary row
    (``T_packed`` / ``_T_shape`` / ``_T_pad``) plus int8 log2 exponents ``E``.
    ``E_accum``, ``group_lr`` and ``T_accum`` buffers are registered as well;
    ``ternary_step`` and ``update_E`` are no-ops.
    """

    def __init__(self, dim, eps=1e-5, threshold=0.05, tscale_type=TScaleType.T64):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.threshold = threshold
        self.tscale_type = tscale_type
        self.group_size = GROUP_SIZES[tscale_type]

        ones_row = torch.ones(1, dim)
        packed_T, _, T_pad = pack_ternary(_ternarize(ones_row, threshold))
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([1, dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # Exponent per group = log2 of the group's mean |w|; the row is
        # zero-padded to whole groups, so a partial last group has a mean < 1.
        groups = ceil(dim / self.group_size)
        padded = F.pad(ones_row.abs(), (0, groups * self.group_size - dim))
        group_mean = padded.view(1, groups, self.group_size).mean(dim=2)
        safe_mean = torch.where(group_mean > 0, group_mean, torch.ones_like(group_mean))
        self.register_buffer("E", safe_mean.flatten().log2().clamp(-128, 127).to(torch.int8))
        self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
        self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
        self.register_buffer("T_accum", torch.zeros(1, dim, dtype=torch.int8))

    def _ensure_E_accum(self):
        """Return ``E_accum``, recreated as int8 zeros if missing or mismatched with ``E``."""
        if not hasattr(self, "E_accum"):
            self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
            return self.E_accum
        stale = self.E_accum.shape != self.E.shape or self.E_accum.device != self.E.device
        if stale:
            self.E_accum = torch.zeros_like(self.E, dtype=torch.int8)
        return self.E_accum

    def _ensure_group_lr(self):
        """Return ``group_lr``, recreated as int8 ones if missing or mismatched with ``E``."""
        if not hasattr(self, "group_lr"):
            self.register_buffer("group_lr", torch.ones_like(self.E, dtype=torch.int8))
            return self.group_lr
        stale = self.group_lr.shape != self.E.shape or self.group_lr.device != self.E.device
        if stale:
            self.group_lr = torch.ones_like(self.E, dtype=torch.int8)
        return self.group_lr

    def _get_T(self):
        """Unpack the ternary weight and drop its leading singleton dimension."""
        unpacked = unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))
        return unpacked.squeeze(0)

    def forward(self, x):
        """Normalize ``x`` over its last dimension.

        CUDA inputs use the Triton kernel when Triton is importable and ``dim``
        does not exceed ``ARB_RMSNORM_TRITON_MAX_DIM``; other CUDA inputs are
        normalized without applying the weight.  Non-CUDA inputs are
        normalized and multiplied by the dequantized weight.
        """
        on_gpu = x.is_cuda
        if on_gpu and _HAS_TRITON and self.dim <= _rmsnorm_triton_max_dim():
            return _TritonRMSNormFn.apply(
                x, self, self.T_packed.contiguous(), self.E.contiguous(),
                self.dim, self.group_size,
            )
        normed = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        if on_gpu:
            # TernaryRMSNorm is initialized as an identity scale and does not
            # train E/T. Avoid unpacking a full large-dim weight or launching
            # the high-register Triton backward kernel on 8GB GPUs.
            return normed
        exponents = _expand_E(self.E, tuple(self._T_shape.tolist()), self.group_size).squeeze(0)
        weight = torch.exp2(exponents.float()) * self._get_T().float()
        return weight * normed

    def ternary_step(self, lr=1, accum_threshold=3):
        """No-op; present so the module shares the ternary layers' training interface."""
        pass

    def update_E(self, lr=1, loss_signal=None):
        """No-op; present so the module shares the ternary layers' training interface."""
        pass

    def extra_repr(self):
        """Summarize the dimension and scale type for ``repr``."""
        return f"dim={self.dim}, tscale_type={self.tscale_type.name}"