"""
Ternary Transformer Inference Engine

Full Qwen2-architecture inference using ternary (1.58-bit) linear layers
with AVX-512-optimized kernels. Linear layers use zero multiplications.

Architecture: DeepSeek-R1-Distill-Qwen-1.5B
- 28 layers, hidden=1536, intermediate=8960
- GQA: 12 query heads, 2 KV heads, head_dim=128
- SwiGLU MLP, RoPE, RMSNorm

(c) 2026 OpenTransformers Ltd / Scott Bisset
"""

import os
import json
import ctypes
import time

import numpy as np


def load_kernel(so_path="ternary_kernel.so"):
    """Load the AVX-512 kernel shared library and declare its C signatures."""
    lib = ctypes.CDLL(so_path)

    # y = W @ x with W packed as ternary pos/neg bitplanes plus per-row scales.
    lib.ternary_matvec_avx512.restype = None
    lib.ternary_matvec_avx512.argtypes = [
        ctypes.c_void_p,  # pos bitplane: uint64[out_dim][ceil(in_dim / 64)]
        ctypes.c_void_p,  # neg bitplane: same layout
        ctypes.c_void_p,  # scales: float32[out_dim]
        ctypes.c_void_p,  # x: float32[in_dim]
        ctypes.c_void_p,  # y: float32[out_dim]
        ctypes.c_int,     # out_dim
        ctypes.c_int,     # in_dim
    ]

    lib.rmsnorm_avx512.restype = None
    lib.rmsnorm_avx512.argtypes = [
        ctypes.c_void_p,  # x
        ctypes.c_void_p,  # weight
        ctypes.c_void_p,  # y
        ctypes.c_int,     # n
        ctypes.c_float,   # eps
    ]

    # In-place SiLU over n floats.
    lib.silu_avx512.restype = None
    lib.silu_avx512.argtypes = [ctypes.c_void_p, ctypes.c_int]

    # out = a * b, elementwise.
    lib.elemwise_mul_avx512.restype = None
    lib.elemwise_mul_avx512.argtypes = [
        ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int
    ]

    lib.softmax.restype = None
    lib.softmax.argtypes = [ctypes.c_void_p, ctypes.c_int]

    # In-place rotary embedding on q and k for one position.
    lib.apply_rope.restype = None
    lib.apply_rope.argtypes = [
        ctypes.c_void_p, ctypes.c_void_p,          # q, k
        ctypes.c_int, ctypes.c_int, ctypes.c_int,  # n_heads, n_kv_heads, head_dim
        ctypes.c_int, ctypes.c_float,              # position, rope theta
    ]

    return lib
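

# The shared library is assumed to export exactly the functions bound above.
# A typical build line for it (an assumption; adjust source file name and
# flags to your toolchain):
#
#   cc -O3 -mavx512f -mavx512bw -shared -fPIC ternary_kernel.c -o ternary_kernel.so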


class TernaryLinear:
    """A linear layer with ternary weights packed as pos/neg bitplanes."""

    def __init__(self, pos_bits, neg_bits, scales, out_dim, in_dim, kernel):
        self.pos = pos_bits
        self.neg = neg_bits
        self.scales = scales
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.kernel = kernel

    def forward(self, x):
        """x: float32[in_dim] -> float32[out_dim]"""
        y = np.zeros(self.out_dim, dtype=np.float32)
        self.kernel.ternary_matvec_avx512(
            self.pos.ctypes.data,
            self.neg.ctypes.data,
            self.scales.ctypes.data,
            x.ctypes.data,
            y.ctypes.data,
            self.out_dim,
            self.in_dim,
        )
        return y
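

# Pure-numpy reference for the ternary matvec, useful for validating the
# AVX-512 kernel. A sketch under one packing assumption: bit j of uint64
# chunk c in row o covers input index c * 64 + j (little-endian packing on a
# little-endian host), matching the (in_dim + 63) // 64 chunk layout used by
# TernaryQwen._load_ternary below.
def ternary_matvec_ref(pos, neg, scales, x):
    """y[o] = scales[o] * (sum of x at +1 bits of row o - sum at -1 bits)."""
    in_dim = x.size
    # View each row's uint64 chunks as bytes, then unpack one bit per input.
    pos_bits = np.unpackbits(pos.view(np.uint8), axis=1, bitorder="little")[:, :in_dim]
    neg_bits = np.unpackbits(neg.view(np.uint8), axis=1, bitorder="little")[:, :in_dim]
    w = pos_bits.astype(np.float32) - neg_bits.astype(np.float32)  # {-1, 0, +1}
    return scales * (w @ x)


# Example check against the kernel (hypothetical layer and input):
#   np.testing.assert_allclose(layer.forward(x),
#       ternary_matvec_ref(layer.pos, layer.neg, layer.scales, x), rtol=1e-3)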


class KVCache:
    def __init__(self, n_layers, n_kv_heads, head_dim, max_seq=4096):
        self.n_layers = n_layers
        self.max_seq = max_seq

        self.k = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32)
                  for _ in range(n_layers)]
        self.v = [np.zeros((max_seq, n_kv_heads, head_dim), dtype=np.float32)
                  for _ in range(n_layers)]
        self.seq_len = 0

    def append(self, layer, k, v):
        """Write k, v ([n_kv_heads, head_dim]) at the current position."""
        pos = self.seq_len
        self.k[layer][pos] = k
        self.v[layer][pos] = v

    def get(self, layer):
        """Return k, v for positions 0..seq_len inclusive, i.e. shape
        [seq_len + 1, n_kv_heads, head_dim], including the entry just
        written by append()."""
        return self.k[layer][:self.seq_len + 1], self.v[layer][:self.seq_len + 1]

    def advance(self):
        self.seq_len += 1
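

# The cache protocol used by TernaryQwen._attention, step by step (a sketch):
#
#   cache = KVCache(n_layers=1, n_kv_heads=2, head_dim=4, max_seq=8)
#   cache.append(0, k, v)        # write K/V for the current position
#   k_all, v_all = cache.get(0)  # positions 0..seq_len inclusive
#   cache.advance()              # move on once the token is fully processed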


class TernaryQwen:
    def __init__(self, model_dir, kernel):
        self.kernel = kernel
        self.model_dir = model_dir

        with open(os.path.join(model_dir, "config.json")) as f:
            self.config = json.load(f)
        with open(os.path.join(model_dir, "manifest.json")) as f:
            self.manifest = json.load(f)

        self.hidden = self.config["hidden_size"]
        self.inter = self.config["intermediate_size"]
        self.n_heads = self.config["num_attention_heads"]
        self.n_kv = self.config["num_key_value_heads"]
        self.head_dim = self.config["head_dim"]
        self.n_layers = self.config["num_hidden_layers"]
        self.vocab = self.config["vocab_size"]
        self.rope_theta = self.config["rope_theta"]
        self.eps = self.config["rms_norm_eps"]

        print(f"Loading ternary model: {self.n_layers} layers, "
              f"hidden={self.hidden}, heads={self.n_heads}/{self.n_kv}")

        t0 = time.time()
        self._load_weights()
        print(f"Model loaded in {time.time() - t0:.1f}s")

        self._compute_memory()

    def _load_ternary(self, key):
        """Load a ternary linear layer from its packed bitplane files."""
        prefix = os.path.join(self.model_dir, key.replace(".", "_"))
        out_dim, in_dim = self.manifest["ternary"][key]
        # Each row packs 64 ternary weights per uint64 chunk: one bitplane
        # for +1 positions (.pos) and one for -1 positions (.neg).
        chunks = (in_dim + 63) // 64

        pos = np.fromfile(prefix + ".pos", dtype=np.uint64).reshape(out_dim, chunks)
        neg = np.fromfile(prefix + ".neg", dtype=np.uint64).reshape(out_dim, chunks)
        scales = np.fromfile(prefix + ".scales", dtype=np.float32)

        # Ensure contiguous buffers before handing raw pointers to the kernel.
        pos = np.ascontiguousarray(pos)
        neg = np.ascontiguousarray(neg)

        return TernaryLinear(pos, neg, scales, out_dim, in_dim, self.kernel)
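
    # manifest.json is assumed to map tensor names to shapes, roughly:
    #   {"ternary": {"model.layers.0.self_attn.q_proj.weight": [1536, 1536], ...},
    #    "fp16":    {"model.embed_tokens.weight": [151936, 1536], ...}}
    # (the exact entries beyond the keys referenced in this file are an assumption).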

    def _load_fp16(self, key):
        """Load an FP16 tensor and upcast it to float32."""
        prefix = os.path.join(self.model_dir, key.replace(".", "_"))
        shape = self.manifest["fp16"][key]
        return np.fromfile(prefix + ".fp16", dtype=np.float16).reshape(shape).astype(np.float32)

    def _load_weights(self):
        """Load all weights."""
        # Token embeddings (also used as the tied LM head if none is stored).
        self.embed = self._load_fp16("model.embed_tokens.weight")

        self.final_norm = self._load_fp16("model.norm.weight")

        # LM head: ternary, separate FP16, or tied to the embeddings.
        if "lm_head.weight" in self.manifest.get("ternary", {}):
            self.lm_head = self._load_ternary("lm_head.weight")
            self.lm_head_ternary = True
        elif "lm_head.weight" in self.manifest.get("fp16", {}):
            self.lm_head_w = self._load_fp16("lm_head.weight")
            self.lm_head_ternary = False
        else:
            # No stored head: tie it to the embedding matrix.
            self.lm_head_w = self.embed
            self.lm_head_ternary = False

        self.layers = []
        for i in range(self.n_layers):
            layer = {}
            prefix = f"model.layers.{i}"

            # Attention projections (ternary).
            layer["q_proj"] = self._load_ternary(f"{prefix}.self_attn.q_proj.weight")
            layer["k_proj"] = self._load_ternary(f"{prefix}.self_attn.k_proj.weight")
            layer["v_proj"] = self._load_ternary(f"{prefix}.self_attn.v_proj.weight")
            layer["o_proj"] = self._load_ternary(f"{prefix}.self_attn.o_proj.weight")

            # SwiGLU MLP projections (ternary).
            layer["gate_proj"] = self._load_ternary(f"{prefix}.mlp.gate_proj.weight")
            layer["up_proj"] = self._load_ternary(f"{prefix}.mlp.up_proj.weight")
            layer["down_proj"] = self._load_ternary(f"{prefix}.mlp.down_proj.weight")

            # RMSNorm weights (FP16).
            layer["input_norm"] = self._load_fp16(f"{prefix}.input_layernorm.weight")
            layer["post_norm"] = self._load_fp16(f"{prefix}.post_attention_layernorm.weight")

            # Qwen2 uses biases on the Q/K/V projections.
            for proj in ["q_proj", "k_proj", "v_proj"]:
                bias_key = f"{prefix}.self_attn.{proj}.bias"
                if bias_key in self.manifest.get("fp16", {}):
                    layer[f"{proj}_bias"] = self._load_fp16(bias_key)

            self.layers.append(layer)
            if (i + 1) % 7 == 0:
                print(f"  Loaded {i + 1}/{self.n_layers} layers")

        if self.n_layers % 7 != 0:
            print(f"  Loaded {self.n_layers}/{self.n_layers} layers")

    def _compute_memory(self):
        """Report memory usage."""
        ternary_bytes = 0
        fp_bytes = 0

        for layer in self.layers:
            for key in ["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"]:
                tl = layer[key]
                ternary_bytes += tl.pos.nbytes + tl.neg.nbytes + tl.scales.nbytes
            for key in ["input_norm", "post_norm"]:
                fp_bytes += layer[key].nbytes

        fp_bytes += self.embed.nbytes + self.final_norm.nbytes
        if self.lm_head_ternary:
            tl = self.lm_head
            ternary_bytes += tl.pos.nbytes + tl.neg.nbytes + tl.scales.nbytes
        elif self.lm_head_w is not self.embed:
            # A tied head aliases the embedding matrix; don't count it twice.
            fp_bytes += self.lm_head_w.nbytes

        total = ternary_bytes + fp_bytes
        print(f"\nMemory: ternary={ternary_bytes / 1024 / 1024:.1f}MB, "
              f"fp={fp_bytes / 1024 / 1024:.1f}MB, total={total / 1024 / 1024:.1f}MB")

    def _rmsnorm(self, x, weight):
        """RMSNorm using the C kernel."""
        y = np.zeros_like(x)
        self.kernel.rmsnorm_avx512(
            x.ctypes.data, weight.ctypes.data, y.ctypes.data,
            len(x), ctypes.c_float(self.eps)
        )
        return y
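
    @staticmethod
    def _rmsnorm_ref(x, weight, eps):
        # Pure-numpy reference for rmsnorm_avx512 (a sketch, assuming the
        # kernel implements the standard y = x / rms(x) * weight):
        return x / np.sqrt(np.mean(x * x) + eps) * weight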

    def _attention(self, x, layer, cache, layer_idx, pos):
        """Grouped-Query Attention for a single token at position `pos`."""
        n_h = self.n_heads
        n_kv = self.n_kv
        hd = self.head_dim

        # QKV projections (ternary matvecs).
        q = layer["q_proj"].forward(x)
        k = layer["k_proj"].forward(x)
        v = layer["v_proj"].forward(x)

        # Qwen2 QKV biases, if present.
        if "q_proj_bias" in layer:
            q += layer["q_proj_bias"]
        if "k_proj_bias" in layer:
            k += layer["k_proj_bias"]
        if "v_proj_bias" in layer:
            v += layer["v_proj_bias"]

        q = q.reshape(n_h, hd)
        k = k.reshape(n_kv, hd)
        v = v.reshape(n_kv, hd)

        # Rotary position embedding, applied in place to q and k.
        self.kernel.apply_rope(
            q.ctypes.data, k.ctypes.data,
            n_h, n_kv, hd, pos,
            ctypes.c_float(self.rope_theta)
        )

        cache.append(layer_idx, k, v)

        # All cached K/V up to and including the current position.
        k_all, v_all = cache.get(layer_idx)

        # Each group of query heads shares one KV head.
        heads_per_kv = n_h // n_kv

        output = np.zeros(n_h * hd, dtype=np.float32)
        scale = 1.0 / np.sqrt(hd)

        for head in range(n_h):
            kv_head = head // heads_per_kv
            q_h = q[head]

            # Attention scores against every cached position.
            scores = np.dot(k_all[:, kv_head, :], q_h) * scale

            # Numerically stable softmax.
            scores_max = np.max(scores)
            scores = np.exp(scores - scores_max)
            scores /= np.sum(scores)

            # Weighted sum of cached values.
            out_h = np.dot(scores, v_all[:, kv_head, :])
            output[head * hd:(head + 1) * hd] = out_h

        return layer["o_proj"].forward(output)
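
    # GQA head mapping: with n_heads=12 and n_kv=2, heads_per_kv=6, so query
    # heads 0-5 attend over KV head 0 and query heads 6-11 over KV head 1.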

    def _mlp(self, x, layer):
        """SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""
        gate = layer["gate_proj"].forward(x)
        up = layer["up_proj"].forward(x)

        # SiLU activation, in place on `gate`.
        self.kernel.silu_avx512(gate.ctypes.data, len(gate))

        # gate *= up, elementwise.
        self.kernel.elemwise_mul_avx512(
            gate.ctypes.data, up.ctypes.data, gate.ctypes.data, len(gate)
        )

        return layer["down_proj"].forward(gate)
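
    @staticmethod
    def _swiglu_ref(gate, up):
        # Pure-numpy reference for the two kernel calls in _mlp (a sketch,
        # assuming silu_avx512 computes the standard silu(x) = x * sigmoid(x)):
        return gate / (1.0 + np.exp(-gate)) * up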

    def forward_token(self, token_id, cache, pos):
        """Forward pass for a single token; returns the final hidden state."""
        x = self.embed[token_id].copy()

        for i, layer in enumerate(self.layers):
            # Pre-norm attention block with residual connection.
            normed = self._rmsnorm(x, layer["input_norm"])
            attn_out = self._attention(normed, layer, cache, i, pos)
            x = x + attn_out

            # Pre-norm MLP block with residual connection.
            normed = self._rmsnorm(x, layer["post_norm"])
            mlp_out = self._mlp(normed, layer)
            x = x + mlp_out

        # Final norm before the LM head.
        x = self._rmsnorm(x, self.final_norm)

        return x

    def logits(self, hidden):
        """Compute vocabulary logits from the final hidden state."""
        if self.lm_head_ternary:
            return self.lm_head.forward(hidden)
        return hidden @ self.lm_head_w.T

    def generate(self, token_ids, max_new_tokens=256, temperature=0.6, top_p=0.95):
        """Generate tokens autoregressively with top-p (nucleus) sampling."""
        cache = KVCache(self.n_layers, self.n_kv, self.head_dim)

        generated = []
        all_tokens = list(token_ids)

        t_start = time.time()

        # Prefill: run the prompt through the model token by token.
        for i, tid in enumerate(token_ids):
            hidden = self.forward_token(tid, cache, i)
            if i < len(token_ids) - 1:
                cache.advance()

        t_prefill = time.time() - t_start

        # Decode loop.
        t_decode_start = time.time()
        for _ in range(max_new_tokens):
            logit_vec = self.logits(hidden)

            if temperature < 0.01:
                # Effectively greedy decoding.
                next_token = int(np.argmax(logit_vec))
            else:
                logit_vec = logit_vec / temperature

                sorted_idx = np.argsort(logit_vec)[::-1]
                sorted_logits = logit_vec[sorted_idx]

                # Numerically stable softmax over the sorted logits.
                max_l = sorted_logits[0]
                probs = np.exp(sorted_logits - max_l)
                probs /= probs.sum()

                # Keep the smallest prefix whose probability mass reaches top_p.
                cumsum = np.cumsum(probs)
                cutoff = np.searchsorted(cumsum, top_p) + 1

                top_probs = probs[:cutoff]
                top_probs /= top_probs.sum()
                top_idx = sorted_idx[:cutoff]

                next_token = int(np.random.choice(top_idx, p=top_probs))

            generated.append(next_token)
            all_tokens.append(next_token)

            # Qwen2 stop tokens: <|endoftext|>, <|im_start|>, <|im_end|>.
            if next_token in [151643, 151644, 151645]:
                break

            cache.advance()
            hidden = self.forward_token(next_token, cache, len(all_tokens) - 1)

        t_total = time.time() - t_start
        t_decode = time.time() - t_decode_start
        n_gen = len(generated)

        stats = {
            "prefill_ms": t_prefill * 1000,
            "decode_ms": t_decode * 1000,
            "total_ms": t_total * 1000,
            "tokens_generated": n_gen,
            "tok_per_sec": n_gen / t_decode if t_decode > 0 else 0,
            "prefill_tokens": len(token_ids),
        }

        return generated, stats
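
    @staticmethod
    def _rope_ref(v, pos, theta):
        # Pure-numpy rotary-embedding reference for one head vector
        # float32[head_dim]. A sketch: the apply_rope kernel's pairing layout
        # is an assumption here; this uses the half-split convention that
        # Qwen2 models use (dim i pairs with dim i + head_dim // 2).
        d = v.shape[-1]
        inv_freq = theta ** (-np.arange(d // 2, dtype=np.float32) * 2.0 / d)
        ang = pos * inv_freq
        cos, sin = np.cos(ang), np.sin(ang)
        lo, hi = v[:d // 2], v[d // 2:]
        return np.concatenate([lo * cos - hi * sin, lo * sin + hi * cos])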


class Tokenizer:
    def __init__(self, model_dir):
        tok_path = os.path.join(model_dir, "tokenizer.json")
        if os.path.exists(tok_path):
            from tokenizers import Tokenizer as HFTokenizer
            self.tok = HFTokenizer.from_file(tok_path)
            self._is_transformers = False
        else:
            # Fall back to the transformers tokenizer.
            from transformers import AutoTokenizer
            self.tok = AutoTokenizer.from_pretrained(model_dir)
            self._is_transformers = True

    def encode(self, text):
        if self._is_transformers:
            return self.tok.encode(text)
        return self.tok.encode(text).ids

    def decode(self, ids):
        if self._is_transformers:
            return self.tok.decode(ids, skip_special_tokens=True)
        return self.tok.decode(ids)

    def apply_chat_template(self, messages):
        """Build the Qwen chat format and open the assistant turn."""
        parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
        parts.append("<|im_start|>assistant\n")
        return "".join(parts)
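
    # For messages = [{"role": "user", "content": "Hi"}] this renders:
    #   <|im_start|>user\nHi<|im_end|><|im_start|>assistant\n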


if __name__ == "__main__":
    import sys

    model_dir = sys.argv[1] if len(sys.argv) > 1 else "deepseek-r1-1.5b-ternary"
    kernel = load_kernel(os.path.join(os.path.dirname(__file__), "ternary_kernel.so"))

    model = TernaryQwen(model_dir, kernel)

    # Pre-tokenized Qwen chat prompt (system turn, user turn, assistant header).
    test_ids = [151644, 8948, 198, 151645, 198, 151644, 872, 198, 9707, 151645, 198, 151644, 77091, 198]

    print("\nGenerating...")
    tokens, stats = model.generate(test_ids, max_new_tokens=50, temperature=0.6)
    print(f"Generated {stats['tokens_generated']} tokens")
    print(f"Speed: {stats['tok_per_sec']:.1f} tok/s")
    print(f"Prefill: {stats['prefill_ms']:.0f}ms, Decode: {stats['decode_ms']:.0f}ms")
    print(f"Token IDs: {tokens}")
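
# End-to-end usage with the tokenizer (a sketch; assumes `model_dir` holds a
# tokenizer.json or a transformers-compatible tokenizer):
#
#   tok = Tokenizer(model_dir)
#   prompt = tok.apply_chat_template([{"role": "user", "content": "Hello"}])
#   ids = tok.encode(prompt)
#   out, _ = model.generate(ids, max_new_tokens=128)
#   print(tok.decode(out))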
|