"""
Convert DeepSeek-R1-Distill-Qwen-1.5B to ternary format.

Stores linear weights as bitplanes (pos_mask, neg_mask) + per-row scale.
Embeddings and layernorms stay FP16. LM head stays FP16.

(c) 2026 OpenTransformers Ltd / Scott Bisset
"""

import json
import os
import struct
import time
from pathlib import Path

import numpy as np


def load_safetensors(model_dir):
    """Load all tensors from every *.safetensors file under *model_dir*.

    Files are read in sorted filename order, so duplicate keys across
    shards are resolved deterministically (last shard wins).

    Args:
        model_dir: directory containing one or more .safetensors shards.

    Returns:
        dict mapping tensor name -> float32 numpy array.
    """
    # Heavy dependencies are imported lazily so the module itself can be
    # imported without torch installed. One import per line (PEP 8).
    import torch  # noqa: F401 -- backend required by safetensors.torch
    from safetensors.torch import load_file

    tensors = {}
    for f in sorted(Path(model_dir).glob("*.safetensors")):
        print(f"Loading {f.name}...")
        state = load_file(str(f))
        for key, val in state.items():
            # Upcast to float32 here; the quantizers expect fp32 input.
            tensors[key] = val.float().numpy()
    return tensors


def quantize_row_ternary(row, alpha=0.7):
    """Quantize a single row to ternary {-1, 0, +1} with vectorized bitpacking.

    Args:
        row: 1-D array of weights.
        alpha: threshold factor; elements with |w| >= alpha * mean(|w|)
            survive as +/-1, everything else becomes 0.

    Returns:
        (pos_bits, neg_bits, scale): two uint64 bitplane arrays packing 64
        weights per word (LSB-first, zero-padded to a multiple of 64) and a
        float32 per-row scale equal to the mean magnitude of the surviving
        weights (1.0 when nothing survives).
    """
    row = row.astype(np.float32)
    mean_abs = np.mean(np.abs(row))
    threshold = alpha * mean_abs

    if threshold > 0:
        pos = row >= threshold
        neg = row <= -threshold
    else:
        # All-zero row: with threshold == 0, every element satisfies both
        # `row >= 0` and `row <= -0`, which would set BOTH bitplanes for
        # zero weights. Treat the whole row as quantized-to-zero instead.
        pos = np.zeros(row.shape, dtype=bool)
        neg = np.zeros(row.shape, dtype=bool)

    nz_mask = pos | neg
    scale = np.mean(np.abs(row[nz_mask])) if nz_mask.any() else np.float32(1.0)

    # Pad to a multiple of 64 so the row packs into whole uint64 words.
    in_dim = len(row)
    pad = (64 - in_dim % 64) % 64
    if pad:
        pos = np.concatenate([pos, np.zeros(pad, dtype=bool)])
        neg = np.concatenate([neg, np.zeros(pad, dtype=bool)])

    # Pack 64 booleans per word: multiply each 0/1 flag by its bit value,
    # then OR-reduce across the 64-lane axis (equivalent to a bit OR sum).
    pos_r = pos.reshape(-1, 64).astype(np.uint64)
    neg_r = neg.reshape(-1, 64).astype(np.uint64)
    bit_positions = (np.uint64(1) << np.arange(64, dtype=np.uint64))
    pos_bits = np.bitwise_or.reduce(pos_r * bit_positions, axis=1)
    neg_bits = np.bitwise_or.reduce(neg_r * bit_positions, axis=1)

    # NOTE: the original had an unreachable duplicate `return` after this
    # one; removed as dead code.
    return pos_bits, neg_bits, np.float32(scale)


def quantize_weight_matrix(weight, alpha=0.7):
    """Quantize an entire weight matrix [out_dim, in_dim] to ternary.

    Fully vectorized: per-row thresholds, per-row scales, and the uint64
    bitplane packing are all computed without Python-level loops over rows.

    Args:
        weight: 2-D array [out_dim, in_dim].
        alpha: threshold factor; per row, |w| >= alpha * mean(|w|) survives.

    Returns:
        (pos_bits, neg_bits, scales, sparsity): uint64 bitplanes of shape
        [out_dim, ceil(in_dim / 64)] (LSB-first within each word), float32
        per-row scales (1.0 for rows with no surviving weights), and the
        fraction of weights quantized to zero.
    """
    w = weight.astype(np.float32)
    out_dim, in_dim = w.shape

    # Per-row threshold: alpha * mean(|w|).
    row_means = np.mean(np.abs(w), axis=1, keepdims=True)
    thresholds = alpha * row_means

    # Guard all-zero rows: a zero threshold would classify every element as
    # both positive and negative (w >= 0 and w <= -0 both hold for 0.0),
    # setting both bitplanes. Such rows quantize entirely to zero instead.
    nonzero_row = row_means > 0
    pos = (w >= thresholds) & nonzero_row
    neg = (w <= -thresholds) & nonzero_row

    nz = pos | neg

    # Per-row scale = mean |w| over surviving weights, vectorized (the
    # original looped over rows in Python). np.maximum(counts, 1) avoids a
    # divide-by-zero warning; np.where still yields 1.0 for empty rows.
    counts = nz.sum(axis=1)
    sums = np.sum(np.abs(w) * nz, axis=1)
    scales = np.where(counts > 0, sums / np.maximum(counts, 1), 1.0).astype(np.float32)

    total = out_dim * in_dim
    sparsity = 1.0 - np.sum(nz) / total

    # Pad columns to a multiple of 64 for whole-word packing.
    pad = (64 - in_dim % 64) % 64
    if pad:
        pos = np.concatenate([pos, np.zeros((out_dim, pad), dtype=bool)], axis=1)
        neg = np.concatenate([neg, np.zeros((out_dim, pad), dtype=bool)], axis=1)

    padded_dim = pos.shape[1]
    chunks = padded_dim // 64

    # Bit k within a word corresponds to column offset k (LSB-first).
    bit_positions = (np.uint64(1) << np.arange(64, dtype=np.uint64))

    pos_r = pos.reshape(out_dim, chunks, 64).astype(np.uint64)
    neg_r = neg.reshape(out_dim, chunks, 64).astype(np.uint64)

    all_pos = np.bitwise_or.reduce(pos_r * bit_positions, axis=2)
    all_neg = np.bitwise_or.reduce(neg_r * bit_positions, axis=2)

    return all_pos, all_neg, scales, sparsity


def save_ternary_model(tensors, output_dir, alpha=0.7):
    """Convert and save the full model to ternary format.

    Linear projection weights are written as bitplane files (.pos/.neg)
    plus per-row scales (.scales); every other tensor is written as raw
    FP16 (.fp16). A config.json and manifest.json describe the layout.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Static architecture parameters for DeepSeek-R1-Distill-Qwen-1.5B,
    # plus the quantization threshold used for this conversion.
    config = {
        "hidden_size": 1536,
        "intermediate_size": 8960,
        "num_attention_heads": 12,
        "num_key_value_heads": 2,
        "num_hidden_layers": 28,
        "vocab_size": 151936,
        "head_dim": 128,
        "rope_theta": 1000000.0,
        "rms_norm_eps": 1e-6,
        "alpha": alpha,
    }

    # Only the seven linear projections go ternary; everything else stays FP16.
    linear_patterns = ('q_proj.weight', 'k_proj.weight', 'v_proj.weight',
                      'o_proj.weight', 'gate_proj.weight', 'up_proj.weight',
                      'down_proj.weight')
    ternary_keys = [k for k in tensors if any(p in k for p in linear_patterns)]
    keep_keys = [k for k in tensors if not any(p in k for p in linear_patterns)]

    print(f"\nTernary layers: {len(ternary_keys)}")
    print(f"FP16 layers: {len(keep_keys)}")

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    total_ternary_bytes = 0
    total_original_bytes = 0

    for key in ternary_keys:
        w = tensors[key].astype(np.float32)
        out_dim, in_dim = w.shape
        total_original_bytes += w.nbytes

        start = time.time()
        pos, neg, scales, sparsity = quantize_weight_matrix(w, alpha)
        elapsed = time.time() - start

        # One file per bitplane plus one for the per-row scales.
        base = os.path.join(output_dir, key.replace(".", "_"))
        pos.tofile(base + ".pos")
        neg.tofile(base + ".neg")
        scales.tofile(base + ".scales")

        ternary_bytes = pos.nbytes + neg.nbytes + scales.nbytes
        total_ternary_bytes += ternary_bytes

        print(f" {key}: {w.shape} -> ternary ({ternary_bytes/1024:.0f}KB, "
              f"{w.nbytes / ternary_bytes:.1f}x compression, {sparsity:.1%} sparse, {elapsed:.1f}s)")

    # Everything not quantized is stored as raw little-endian FP16.
    total_fp16_bytes = 0
    for key in keep_keys:
        half = tensors[key].astype(np.float16)
        base = os.path.join(output_dir, key.replace(".", "_"))
        half.tofile(base + ".fp16")
        total_fp16_bytes += half.nbytes
        print(f" {key}: {half.shape} -> fp16 ({half.nbytes/1024:.0f}KB)")

    # Manifest records the logical (pre-padding) shape of every tensor.
    manifest = {
        "ternary": {k: list(tensors[k].shape) for k in ternary_keys},
        "fp16": {k: list(tensors[k].shape) for k in keep_keys},
    }
    with open(os.path.join(output_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f, indent=2)

    grand_total = total_ternary_bytes + total_fp16_bytes
    fp32_total = total_original_bytes + total_fp16_bytes
    mib = 1024 * 1024
    print(f"\n=== Summary ===")
    print(f"Original FP32 linear weights: {total_original_bytes/mib:.1f} MB")
    print(f"Ternary linear weights: {total_ternary_bytes/mib:.1f} MB")
    print(f"FP16 other weights: {total_fp16_bytes/mib:.1f} MB")
    print(f"Total model size: {grand_total/mib:.1f} MB")
    print(f"Compression vs FP32: {fp32_total/grand_total:.1f}x")


if __name__ == "__main__":
    import sys

    # Positional CLI arguments, all optional: model_dir, output_dir, alpha.
    argv = sys.argv[1:]
    model_dir = argv[0] if len(argv) > 0 else "deepseek-r1-1.5b-hf"
    output_dir = argv[1] if len(argv) > 1 else "deepseek-r1-1.5b-ternary"
    alpha = float(argv[2]) if len(argv) > 2 else 0.7

    print(f"Loading model from {model_dir}...")
    tensors = load_safetensors(model_dir)

    print(f"Converting to ternary (alpha={alpha})...")
    save_ternary_model(tensors, output_dir, alpha)
    print("Done!")