# needle-onnx / convert_weights.py
# Uploaded by shreyask with huggingface_hub (commit 03f1e75, verified).
"""Convert Cactus Needle's Flax checkpoint to a PyTorch state_dict.

HF source: Cactus-Compute/needle / needle.pkl

Usage:
    cd export
    uv run python convert_weights.py [--ckpt-repo REPO] [--ckpt-file FILE] [--out PATH]

Output (by default): export/artifacts/needle_torch.pt (the state_dict) plus
export/artifacts/needle_torch.config.json (the checkpoint's model config).
"""
import pickle
import sys
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import hf_hub_download
# Make the PyTorch port importable from export/
sys.path.insert(0, str(Path(__file__).resolve().parent))
from needle_torch import NeedleModel, TransformerConfig
ART = Path(__file__).resolve().parent / "artifacts"
ART.mkdir(exist_ok=True)
_HF_REPO_DEFAULT = "Cactus-Compute/needle"
_HF_FILE_DEFAULT = "needle.pkl"
def load_flax_checkpoint(repo_id: str = _HF_REPO_DEFAULT, filename: str = _HF_FILE_DEFAULT):
"""Download a Cactus-format checkpoint from HF and return the raw dict.
Works for any model trained with Cactus's pipeline because the training code
always saves `{"config": <dict>, "params": <pytree>}` in the same shape.
Pass a different repo/filename to point at a finetuned variant — the rest
of this script reads `data["config"]` to parametrize the PyTorch port, so
dim changes (d_model, layer counts, GQA ratios) are picked up automatically.
"""
local_dir = str(ART)
print(f"Downloading {filename} from {repo_id}...", flush=True)
path = hf_hub_download(
repo_id=repo_id,
filename=filename,
repo_type="model",
local_dir=local_dir,
)
print(f"Loaded from {path}", flush=True)
with open(path, "rb") as f:
data = pickle.load(f)
return data
# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------
def _to_f32(arr):
"""Convert any array-like (JAX, numpy, bfloat16) to a float32 numpy array."""
return np.asarray(arr).astype(np.float32)
def copy_kernel(new_state, flax_t, pt_name, i=None):
    """Store a Flax Dense kernel in ``new_state`` as a PyTorch Linear weight.

    Flax Dense kernels are laid out ``(in, out)`` while PyTorch Linear weights
    are ``(out, in)``, so the matrix is transposed.

    Args:
        new_state: Dict of PyTorch state_dict entries being built (mutated).
        flax_t: Kernel array. When ``i`` is given it carries a leading
            ``nn.scan`` layer axis, i.e. ``(num_layers, in, out)``.
        pt_name: Destination state_dict key.
        i: Index into the leading scan axis, or ``None`` for an unstacked kernel.

    Raises:
        ValueError: If the kernel is not 2-D after the optional slice. ``.T`` on
            any other rank would silently reverse every axis (or be a no-op for
            1-D) instead of swapping in/out.
    """
    arr = _to_f32(flax_t)
    if i is not None:
        arr = arr[i]  # strip the scan axis -> (in, out)
    if arr.ndim != 2:
        where = "" if i is None else f" (layer {i})"
        raise ValueError(
            f"{pt_name}{where}: expected a 2-D (in, out) kernel, got shape {arr.shape}"
        )
    # .T is a strided view; .copy() makes a C-contiguous buffer owned by the new
    # tensor instead of pinning the (possibly layer-stacked) source array.
    new_state[pt_name] = torch.from_numpy(arr.T.copy())
def copy_vector(new_state, flax_t, pt_name, i=None):
    """Store a 1-D scale/bias or a 0-D scalar in ``new_state`` unchanged in layout.

    Args:
        new_state: Dict of PyTorch state_dict entries being built (mutated).
        flax_t: Source array; when ``i`` is given, its leading axis is the
            ``nn.scan`` layer axis and entry ``i`` is taken.
        pt_name: Destination state_dict key.
        i: Optional index into the leading axis.
    """
    values = _to_f32(flax_t)
    picked = values if i is None else values[i]
    # Indexing a 1-D array yields a NumPy scalar rather than an ndarray;
    # np.array(...) turns it back into a (0-D) array that torch.from_numpy
    # accepts, and always yields an independent copy.
    new_state[pt_name] = torch.from_numpy(np.array(picked, copy=True))
# ---------------------------------------------------------------------------
# Main conversion
# ---------------------------------------------------------------------------
def _parse_args(argv=None):
    """Parse the command-line options for :func:`main`.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``ckpt_repo``, ``ckpt_file`` and ``out``.
    """
    import argparse

    p = argparse.ArgumentParser(description=(
        "Convert a Cactus-format Flax checkpoint to a PyTorch state_dict for the "
        "needle_torch port. Defaults to the published Cactus-Compute/needle weights; "
        "pass --ckpt-repo / --ckpt-file to convert a finetuned variant."
    ))
    p.add_argument("--ckpt-repo", default=_HF_REPO_DEFAULT,
                   help=f"HF repo containing the checkpoint (default: {_HF_REPO_DEFAULT})")
    p.add_argument("--ckpt-file", default=_HF_FILE_DEFAULT,
                   help=f"Filename within the repo (default: {_HF_FILE_DEFAULT})")
    p.add_argument("--out", default=str(ART / "needle_torch.pt"),
                   help="Output path for the PyTorch state_dict (default: artifacts/needle_torch.pt)")
    return p.parse_args(argv)


def _check_scan_depth(block, name, expected):
    """Exit if an ``nn.scan``-stacked block's leading layer axis != ``expected``.

    Without this check, a checkpoint that stacks more layers than its config
    declares would have the surplus layers silently dropped, because the
    PyTorch model (and thus the key check) is built from the config.
    """
    depth = np.shape(block["ZCRMSNorm_0"]["scale"])[0]
    if depth != expected:
        sys.exit(f"{name}: checkpoint stacks {depth} layers but config declares {expected}")


def _copy_attention(new_state, flax_attn, prefix, i):
    """Copy one attention module's projections and QK norms for scan layer ``i``."""
    # Linear kernels need the Flax (in, out) -> PyTorch (out, in) transpose.
    for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
        copy_kernel(new_state, flax_attn[proj]["kernel"], f"{prefix}.{proj}.weight", i)
    # QK-norm scale vectors are copied without transpose.
    for norm in ("q_norm", "k_norm"):
        copy_vector(new_state, flax_attn[norm]["scale"], f"{prefix}.{norm}.scale", i)


def _build_state_dict(flax_params, pt_config):
    """Translate the Flax parameter tree into a PyTorch state_dict.

    Args:
        flax_params: The checkpoint's ``params`` pytree (nested dicts of arrays).
        pt_config: The ``TransformerConfig`` used to build the PyTorch model;
            only the encoder/decoder layer counts are read here.

    Returns:
        Dict mapping PyTorch state_dict keys to float32 tensors.
    """
    new_state = {}

    # --- Top-level scalars ---
    copy_vector(new_state, flax_params["log_temp"], "log_temp")

    # --- Shared embedding (no transpose -- Flax Embed stores (vocab, d_model)) ---
    # The PyTorch state_dict exposes the shared table under three keys; all
    # three point at the same tensor.
    emb_tensor = torch.from_numpy(_to_f32(flax_params["embedding"]["embedding"]).copy())
    for key in ("embedding.weight", "encoder.embedding.weight", "decoder.embedding.weight"):
        new_state[key] = emb_tensor

    # --- Contrastive head ---
    # contrastive_hidden: kernel (d_model, d_model//4), bias (d_model//4,)
    copy_kernel(new_state, flax_params["contrastive_hidden"]["kernel"], "contrastive_hidden.weight")
    copy_vector(new_state, flax_params["contrastive_hidden"]["bias"], "contrastive_hidden.bias")
    # contrastive_proj: kernel (d_model//4, contrastive_dim), no bias
    copy_kernel(new_state, flax_params["contrastive_proj"]["kernel"], "contrastive_proj.weight")

    # --- Encoder: final norm, then nn.scan-stacked layers (EncoderBlock_0) ---
    copy_vector(new_state, flax_params["encoder"]["final_norm"]["scale"], "encoder.final_norm.scale")
    enc_block = flax_params["encoder"]["layers"]["EncoderBlock_0"]
    _check_scan_depth(enc_block, "encoder", pt_config.num_encoder_layers)
    for i in range(pt_config.num_encoder_layers):
        base = f"encoder.layers.{i}"
        copy_vector(new_state, enc_block["attn_gate"], f"{base}.attn_gate", i)
        # Pre-norm: ZCRMSNorm_0 -> layers.i.norm
        copy_vector(new_state, enc_block["ZCRMSNorm_0"]["scale"], f"{base}.norm.scale", i)
        _copy_attention(new_state, enc_block["self_attn"], f"{base}.self_attn", i)

    # --- Decoder: final norm (Flax decoder.ZCRMSNorm_0 -> decoder.final_norm) ---
    copy_vector(new_state, flax_params["decoder"]["ZCRMSNorm_0"]["scale"], "decoder.final_norm.scale")
    dec_block = flax_params["decoder"]["layers"]["DecoderBlock_0"]
    _check_scan_depth(dec_block, "decoder", pt_config.num_decoder_layers)
    for i in range(pt_config.num_decoder_layers):
        base = f"decoder.layers.{i}"
        copy_vector(new_state, dec_block["self_attn_gate"], f"{base}.self_attn_gate", i)
        copy_vector(new_state, dec_block["cross_attn_gate"], f"{base}.cross_attn_gate", i)
        # ZCRMSNorm_0 = self-attn pre-norm, ZCRMSNorm_1 = cross-attn pre-norm
        copy_vector(new_state, dec_block["ZCRMSNorm_0"]["scale"], f"{base}.self_norm.scale", i)
        copy_vector(new_state, dec_block["ZCRMSNorm_1"]["scale"], f"{base}.cross_norm.scale", i)
        _copy_attention(new_state, dec_block["self_attn"], f"{base}.self_attn", i)
        _copy_attention(new_state, dec_block["cross_attn"], f"{base}.cross_attn", i)

    return new_state


def _check_state_dict(new_state, target_state):
    """Exit with a readable report if keys or shapes differ from the model's."""
    missing = sorted(set(target_state) - set(new_state))
    extra = sorted(set(new_state) - set(target_state))
    if missing or extra:
        if missing:
            print("MISSING keys (in model, not in new_state):")
            for k in missing:
                print(f" {k}")
        if extra:
            print("EXTRA keys (in new_state, not in model):")
            for k in extra:
                print(f" {k}")
        sys.exit("state_dict mismatch -- fix the mapping")

    # Keys now match exactly, so every new_state key exists in target_state.
    shape_errors = []
    for k, tensor in new_state.items():
        expected = tuple(target_state[k].shape)
        got = tuple(tensor.shape)
        if expected != got:
            shape_errors.append(f" {k}: model expects {expected}, got {got}")
    if shape_errors:
        print("SHAPE MISMATCHES:")
        for e in shape_errors:
            print(e)
        sys.exit("shape mismatch -- fix transpositions")


def main():
    """Download a Cactus checkpoint, convert it to PyTorch, verify it, and save it.

    Writes the state_dict to ``--out`` and the checkpoint's config as JSON next
    to it (same stem, ``.config.json`` suffix). Exits non-zero on any key or
    shape mismatch between the converted tensors and the PyTorch port.
    """
    import json

    args = _parse_args()

    # ---- Step 1: download + load Flax checkpoint ----
    data = load_flax_checkpoint(args.ckpt_repo, args.ckpt_file)
    config_dict = data["config"]
    print(f"\nCheckpoint config: {config_dict}\n")
    flax_params = data["params"]

    # ---- Step 2: instantiate PyTorch port with checkpoint config ----
    pt_config = TransformerConfig(**config_dict)
    model = NeedleModel(pt_config)
    model.eval()
    target_state = model.state_dict()

    # ---- Step 3: walk Flax tree and fill new_state ----
    new_state = _build_state_dict(flax_params, pt_config)

    # ---- Step 4: verify completeness and shapes before loading ----
    _check_state_dict(new_state, target_state)

    # ---- Step 5: load and verify ----
    result = model.load_state_dict(new_state, strict=True)
    # strict=True already raises on mismatch; this is an explicit belt-and-braces
    # check that, unlike `assert`, survives `python -O`.
    if result.missing_keys or result.unexpected_keys:
        raise RuntimeError(f"load_state_dict unexpected result: {result}")
    print(f"\nSuccessfully loaded {len(new_state)} tensors into PyTorch port (strict=True)")
    print(f"Config: {config_dict}")

    # ---- Step 6: save ----
    out_path = Path(args.out)
    # --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(new_state, out_path)
    print(f"Saved -> {out_path}")
    # Also save the config as JSON next to the .pt so export_onnx.py can rebuild
    # the model with the right dims for any finetuned variant.
    config_out = out_path.with_suffix(".config.json")
    config_out.write_text(json.dumps(config_dict, indent=2))
    print(f"Saved -> {config_out}")
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()