"""Save/load helpers for the trigram ARB model with 2-bit packed ternary weights."""

import os
import sys
import warnings

import torch


def _is_ternary_param(name, param):
    """Return True if parameter ``param`` (named ``name``) is stored as packed ternary.

    Shared by :func:`save_model` and :func:`main` so the selection rule is
    defined once: the name contains ``"weight"`` but not ``"embed"``, and the
    tensor has at least two dimensions.
    """
    return "weight" in name and param.ndim >= 2 and "embed" not in name


def pack_ternary(w):
    """Pack the sign pattern of ``w`` into 2-bit codes, four codes per byte.

    Each element maps to a code: ``0`` if ``w < 0``, ``1`` if ``w == 0``,
    ``2`` if ``w > 0``. Elements matching none of these (NaN) are encoded as
    ``1``. Codes are taken in row-major (flatten) order and padded with ``0``
    to a multiple of four. Element ``4k + j`` occupies bits ``2j..2j+1`` of
    byte ``k``.

    Args:
        w: tensor of any shape, on any device.

    Returns:
        ``(packed, shape)``: ``packed`` is a 1-D ``uint8`` CPU tensor of length
        ``ceil(w.numel() / 4)``; ``shape`` is ``w.shape``.
    """
    # Start every element at the "zero" code so each byte is defined even
    # when no comparison below matches (NaN). With torch.empty_like those
    # bytes would be uninitialised, and values > 3 would corrupt neighbouring
    # codes after the shifts.
    q = torch.ones_like(w, dtype=torch.uint8)
    q[w < 0] = 0
    q[w > 0] = 2
    flat = q.flatten()
    pad = (-flat.numel()) % 4
    if pad:
        flat = torch.cat(
            [flat, torch.zeros(pad, dtype=torch.uint8, device=flat.device)]
        )
    # reshape (not view) so a non-contiguous input cannot make this fail.
    groups = flat.reshape(-1, 4)
    packed = (
        groups[:, 0]
        | (groups[:, 1] << 2)
        | (groups[:, 2] << 4)
        | (groups[:, 3] << 6)
    )
    return packed.cpu(), w.shape


def save_model(model, path="trigram-morph.pt"):
    """Save ``model`` plus a packed ternary copy of its eligible weights to ``path``.

    The checkpoint dict has these keys:

    - ``"model_state_dict"``: the full state dict.
    - ``"config"``: hyper-parameters.
    - ``"ternary_packed"``: ``{name: {"packed": uint8 tensor, "shape": torch.Size}}``
      for every parameter selected by ``_is_ternary_param``, quantized with
      ``StickyZoneSTE`` at ``THRESHOLD``.
    - ``"format"`` and ``"bpw"``: descriptive metadata.

    NOTE(review): ``StickyZoneSTE``, ``THRESHOLD``, ``VOCAB``,
    ``EMBEDDING_DIM``, ``HIDDEN_DIM``, ``FFN_HIDDEN`` and ``CTX`` are neither
    defined nor imported in this file. This function raises ``NameError``
    unless they are present in the module namespace. They presumably live in
    the ``trigram`` model module next to ``ARBModel`` -- confirm and import.

    Args:
        model: module providing ``named_parameters``, ``parameters`` and
            ``state_dict``.
        path: destination file.
    """
    ternary_weights = {}
    # Quantization is for export only; no autograd graph is needed.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not _is_ternary_param(name, param):
                continue
            T = StickyZoneSTE.apply(param.detach(), THRESHOLD)
            packed, shape = pack_ternary(T)
            ternary_weights[name] = {"packed": packed, "shape": shape}

    torch.save({
        "model_state_dict": model.state_dict(),
        "config": {
            "vocab": VOCAB,
            "embedding_dim": EMBEDDING_DIM,
            "trigram_dim": HIDDEN_DIM,
            "ffn_hidden": FFN_HIDDEN,
            "ctx": CTX,
            "threshold": THRESHOLD,
        },
        "ternary_packed": ternary_weights,
        "format": "factorized_scaled_ternary",
        # NOTE: pack_ternary stores 2 bits per weight. 1.58 is log2(3), the
        # information-theoretic rate, not the on-disk rate. The value is kept
        # unchanged for compatibility with existing readers.
        "bpw": 1.58,
    }, path)
    total = sum(p.numel() for p in model.parameters())
    print(f"Saved {total:,} params to {path}")


def load_model(path="trigram-morph.pt", device="cpu"):
    """Rebuild an ``ARBModel`` from a checkpoint written by :func:`save_model`.

    Only ``checkpoint["model_state_dict"]`` is used. The stored ``config`` and
    packed ternary weights are ignored, and ``ARBModel`` is constructed with
    its default arguments. Loading is non-strict: missing or unexpected keys
    are reported with ``warnings.warn`` rather than raising.

    Security: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    objects. Only load checkpoints from trusted sources.

    Args:
        path: checkpoint file.
        device: ``map_location`` for loading and the device the model is moved to.

    Returns:
        The model on ``device``, in eval mode.
    """
    # ARBModel is not imported at module level. Import it from the same
    # location the __main__ entry point uses, so this works when imported.
    from ..trigram import ARBModel

    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model = ARBModel()
    result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_model({path!r}): missing keys {result.missing_keys}, "
            f"unexpected keys {result.unexpected_keys}",
            stacklevel=2,
        )
    model.to(device)
    model.eval()
    return model


def main():
    """Print a parameter breakdown for a freshly built ``ARBModel`` and save it."""
    # Relative import: run as ``python -m <package>.<this module>``. Executing
    # the file directly fails with "attempted relative import with no known
    # parent package".
    from ..trigram import ARBModel

    model = ARBModel()
    total = sum(p.numel() for p in model.parameters())
    ternary = sum(
        p.numel() for n, p in model.named_parameters() if _is_ternary_param(n, p)
    )
    # Every remaining (non-ternary) parameter.
    fp32 = total - ternary
    print(f"Total params: {total:,}")
    print(f"Ternary params (1.58 BPW): {ternary:,}")
    print(f"FP32 params: {fp32:,}")
    save_model(model)


if __name__ == "__main__":
    main()