ARBS / arbitor /converters /convert_to_ternary2.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
import os
import sys
import torch
def pack_ternary(w):
    """Pack a sign-coded tensor into bytes, four 2-bit codes per byte.

    Every element of ``w`` is mapped to a 2-bit code by its sign:
    negative -> 0, zero -> 1, positive -> 2. Codes are taken in flattened
    order and element ``4 * i + k`` is stored in bits ``2 * k`` and
    ``2 * k + 1`` of output byte ``i``.

    If the element count is not a multiple of four, the tail is padded with
    code 0. Code 0 is also the "negative" code, so a reader must use the
    returned shape to know how many codes are real.

    Args:
        w: Tensor of any shape. Only the sign of each element is used.
            Elements that are neither <0, ==0 nor >0 (i.e. NaN) are encoded
            as zero (code 1).

    Returns:
        A ``(packed, shape)`` tuple: ``packed`` is a 1-D ``torch.uint8``
        tensor on the CPU with ``ceil(w.numel() / 4)`` entries, and
        ``shape`` is ``w.shape``.
    """
    # Start from the "zero" code so that every element has a defined value;
    # an uninitialised buffer would leak garbage bytes for NaN entries.
    codes = torch.ones_like(w, dtype=torch.uint8)
    codes[w < 0] = 0
    codes[w > 0] = 2

    flat = codes.flatten()
    pad = (-flat.numel()) % 4
    if pad:
        flat = torch.cat(
            [flat, torch.zeros(pad, dtype=torch.uint8, device=flat.device)]
        )

    # Explicit row count: view(-1, 4) is ambiguous (and raises) when the
    # tensor has zero elements.
    quads = flat.view(flat.numel() // 4, 4)
    packed = (
        quads[:, 0]
        | (quads[:, 1] << 2)
        | (quads[:, 2] << 4)
        | (quads[:, 3] << 6)
    )
    return packed.cpu(), w.shape
def save_model(model, path="trigram-morph.pt"):
    """Write a checkpoint for ``model`` to ``path`` and print a summary.

    The checkpoint is a dict saved with ``torch.save`` containing:

    * ``"model_state_dict"``: the model's full ``state_dict()``.
    * ``"config"``: the module-level hyperparameter constants.
    * ``"ternary_packed"``: for every parameter whose name contains
      ``"weight"`` but not ``"embed"`` and that has at least two dimensions,
      the result of ``pack_ternary`` applied to
      ``StickyZoneSTE.apply(param.data, THRESHOLD)``.
    * ``"format"`` and ``"bpw"``: fixed metadata values.

    Args:
        model: Object exposing ``named_parameters()``, ``parameters()`` and
            ``state_dict()``.
        path: Destination file for the checkpoint.
    """
    # NOTE(review): StickyZoneSTE, THRESHOLD, VOCAB, EMBEDDING_DIM,
    # HIDDEN_DIM, FFN_HIDDEN and CTX are not bound in the part of the file
    # visible here -- confirm they are provided at module scope.
    ternary_weights = {}
    for param_name, tensor in model.named_parameters():
        is_matrix_weight = "weight" in param_name and tensor.ndim >= 2
        if not is_matrix_weight or "embed" in param_name:
            continue
        quantized = StickyZoneSTE.apply(tensor.data, THRESHOLD)
        packed, shape = pack_ternary(quantized)
        ternary_weights[param_name] = {"packed": packed, "shape": shape}

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "config": {
            "vocab": VOCAB,
            "embedding_dim": EMBEDDING_DIM,
            "trigram_dim": HIDDEN_DIM,
            "ffn_hidden": FFN_HIDDEN,
            "ctx": CTX,
            "threshold": THRESHOLD,
        },
        "ternary_packed": ternary_weights,
        "format": "factorized_scaled_ternary",
        "bpw": 1.58,
    }
    torch.save(checkpoint, path)

    total = sum(p.numel() for p in model.parameters())
    print(f"Saved {total:,} params to {path}")
def load_model(path="trigram-morph.pt", device="cpu"):
    """Load a checkpoint and return a fresh ``ARBModel`` in eval mode.

    Args:
        path: Checkpoint file; it must contain a ``"model_state_dict"``
            entry.
        device: Passed to ``torch.load`` as ``map_location`` and used as the
            target of ``model.to``.

    Returns:
        A default-constructed ``ARBModel`` with the checkpoint's state dict
        loaded non-strictly, moved to ``device`` and switched to eval mode.
    """
    # ARBModel is only bound at module scope when this file runs as a
    # script; resolve it lazily so the function also works when the module
    # is imported from elsewhere in the package.
    try:
        model_cls = ARBModel
    except NameError:
        from ..trigram import ARBModel as model_cls

    # NOTE(security): weights_only=False unpickles arbitrary objects, so
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    model = model_cls()
    # strict=False: keys missing from or unexpected in the checkpoint are
    # ignored instead of raising.
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    model.to(device)
    model.eval()
    return model
if __name__ == "__main__":
    # Script entry point: build a default model, report how its parameters
    # split between the ternary-packed set and the rest, then save it.
    from ..trigram import ARBModel

    model = ARBModel()

    total = sum(p.numel() for p in model.parameters())

    # Parameters that save_model packs: 2-D+ "weight" tensors that are not
    # embeddings.
    ternary = sum(
        [
            p.numel()
            for n, p in model.named_parameters()
            if "embed" not in n and "weight" in n and p.ndim >= 2
        ]
    )
    # Everything else (the logical complement of the filter above).
    fp32 = sum(
        [
            p.numel()
            for n, p in model.named_parameters()
            if "weight" not in n or p.ndim < 2 or "embed" in n
        ]
    )

    print(f"Total params: {total:,}")
    print(f"Ternary params (1.58 BPW): {ternary:,}")
    print(f"FP32 params: {fp32:,}")

    save_model(model)