| import os |
| import sys |
| import torch |
|
|
|
|
|
|
|
|
def pack_ternary(w):
    """Encode a tensor's signs as 2-bit codes packed four per byte.

    Each element is mapped by sign: negative -> 0, zero -> 1, positive -> 2.
    Elements that satisfy none of those comparisons (NaN) are encoded as the
    zero code 1 instead of being left as uninitialised memory.

    Codes are taken in ``flatten()`` order and grouped in fours; the i-th
    element of each group occupies bits ``2*i .. 2*i+1`` of its output byte.
    The flattened sequence is right-padded with code 0 up to a multiple of
    four, so an unpacker must use the returned shape to drop the padding.

    Args:
        w: Tensor whose signs are packed (any shape, any device).

    Returns:
        ``(packed, shape)`` where ``packed`` is a 1-D ``torch.uint8`` CPU
        tensor of ``ceil(w.numel() / 4)`` bytes and ``shape`` is ``w.shape``.
    """
    # Default every element to the zero code. With empty_like, NaNs kept
    # arbitrary byte values that leaked into neighbouring 2-bit fields once
    # shifted and OR-ed together below.
    q = torch.ones_like(w, dtype=torch.uint8)
    q[w < 0] = 0
    q[w > 0] = 2

    flat = q.flatten()
    pad = (-flat.numel()) % 4
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])

    groups = flat.view(-1, 4)
    packed = (
        groups[:, 0]
        | (groups[:, 1] << 2)
        | (groups[:, 2] << 4)
        | (groups[:, 3] << 6)
    )
    return packed.cpu(), w.shape
|
|
|
|
def save_model(model, path="trigram-morph.pt"):
    """Save *model* and a 2-bit packed copy of its ternary weights to *path*.

    A parameter is treated as ternary when its name contains "weight", it has
    at least two dimensions, and its name does not contain "embed". Each such
    parameter is quantised with ``StickyZoneSTE`` at ``THRESHOLD`` and packed
    with ``pack_ternary``. The full ``model.state_dict()`` is stored as well,
    and the packed weights sit next to it under ``"ternary_packed"``.
    """
    def is_ternary(param_name, tensor):
        return "weight" in param_name and tensor.ndim >= 2 and "embed" not in param_name

    def pack_entry(tensor):
        codes, original_shape = pack_ternary(StickyZoneSTE.apply(tensor.data, THRESHOLD))
        return {"packed": codes, "shape": original_shape}

    ternary_weights = {
        param_name: pack_entry(tensor)
        for param_name, tensor in model.named_parameters()
        if is_ternary(param_name, tensor)
    }

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "config": {
            "vocab": VOCAB,
            "embedding_dim": EMBEDDING_DIM,
            "trigram_dim": HIDDEN_DIM,
            "ffn_hidden": FFN_HIDDEN,
            "ctx": CTX,
            "threshold": THRESHOLD,
        },
        "ternary_packed": ternary_weights,
        "format": "factorized_scaled_ternary",
        # Nominal log2(3) figure; pack_ternary physically stores 2 bits per
        # weight (four codes per byte).
        "bpw": 1.58,
    }
    torch.save(checkpoint, path)

    param_count = sum(t.numel() for t in model.parameters())
    print(f"Saved {param_count:,} params to {path}")
|
|
|
|
def load_model(path="trigram-morph.pt", device="cpu"):
    """Load a checkpoint written by ``save_model`` into a fresh ``ARBModel``.

    The model is built with ``ARBModel()`` defaults (the checkpoint's
    ``"config"`` entry is not consulted) and populated from
    ``"model_state_dict"``. Loading stays non-strict, but missing or
    unexpected keys are now reported with a warning instead of being
    silently ignored.

    Args:
        path: Checkpoint file to read.
        device: Used as ``map_location`` for ``torch.load`` and as the
            target of ``model.to``.

    Returns:
        The populated model on *device*, switched to eval mode.
    """
    import warnings

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model = ARBModel()
    result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_model({path!r}): missing keys {list(result.missing_keys)}, "
            f"unexpected keys {list(result.unexpected_keys)}",
            stacklevel=2,
        )
    model.to(device)
    model.eval()
    return model
|
|
|
|
if __name__ == "__main__":
    from ..trigram import ARBModel

    def _is_ternary(param_name, tensor):
        # Same selection rule that save_model applies when packing weights.
        return "weight" in param_name and tensor.ndim >= 2 and "embed" not in param_name

    model = ARBModel()
    named = list(model.named_parameters())
    total = sum(t.numel() for t in model.parameters())
    ternary = sum(t.numel() for n, t in named if _is_ternary(n, t))
    fp32 = sum(t.numel() for n, t in named if not _is_ternary(n, t))

    print(f"Total params: {total:,}")
    print(f"Ternary params (1.58 BPW): {ternary:,}")
    print(f"FP32 params: {fp32:,}")
    save_model(model)
|
|