"""Convert Cactus Needle's Flax checkpoint to a PyTorch state_dict.

HF source: Cactus-Compute/needle / needle.pkl

Usage:
    cd export
    uv run python convert_weights.py [--ckpt-repo REPO] [--ckpt-file FILE] [--out PATH]

Output: export/artifacts/needle_torch.pt, plus a sidecar
export/artifacts/needle_torch.config.json holding the checkpoint config.
"""
import pickle
import sys
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import hf_hub_download
# Make the PyTorch port importable from export/
sys.path.insert(0, str(Path(__file__).resolve().parent))
from needle_torch import NeedleModel, TransformerConfig
# Directory (next to this script) that receives the downloaded checkpoint and
# the converted outputs. It is created as a side effect of importing this module.
ART = Path(__file__).resolve().parent.joinpath("artifacts")
ART.mkdir(exist_ok=True)

# Default Hugging Face location of the published Cactus Needle checkpoint.
_HF_REPO_DEFAULT = "Cactus-Compute/needle"
_HF_FILE_DEFAULT = "needle.pkl"
def load_flax_checkpoint(repo_id: str = _HF_REPO_DEFAULT, filename: str = _HF_FILE_DEFAULT,
                         revision=None):
    """Download a Cactus-format checkpoint from HF and return the raw dict.

    Works for any model trained with Cactus's pipeline because the training code
    always saves ``{"config": <dict>, "params": <pytree>}`` in the same shape.
    Pass a different repo/filename to point at a finetuned variant -- the rest
    of this script reads ``data["config"]`` to parametrize the PyTorch port, so
    dim changes (d_model, layer counts, GQA ratios) are picked up automatically.

    Args:
        repo_id: Hugging Face model repo that holds the checkpoint.
        filename: Pickle file inside ``repo_id``.
        revision: Optional branch, tag or commit hash passed straight to
            ``hf_hub_download``. ``None`` (the default) keeps the previous
            behavior of downloading the hub's default revision. Pinning a commit
            hash guarantees you unpickle exactly the file you reviewed.

    Returns:
        The unpickled object. It is expected to be the dict described above,
        but this function does not validate it.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        point this at repos you trust, ideally pinned with ``revision``.
    """
    print(f"Downloading {filename} from {repo_id}...", flush=True)
    path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        revision=revision,
        local_dir=str(ART),  # keep the raw checkpoint next to the converted artifacts
    )
    # SECURITY: unpickling untrusted data can run arbitrary code -- see docstring.
    with open(path, "rb") as f:
        data = pickle.load(f)
    # Report only after the file has actually been deserialized.
    print(f"Loaded from {path}", flush=True)
    return data
# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------
def _to_f32(arr):
    """Return ``arr`` as a freshly allocated float32 NumPy array.

    Accepts anything ``np.asarray`` understands (NumPy arrays, JAX arrays,
    nested sequences, bfloat16 buffers). ``astype`` copies by default, so the
    result never aliases the input.
    """
    host_array = np.asarray(arr)
    return host_array.astype(np.float32)
def copy_kernel(new_state, flax_t, pt_name, i=None):
    """Store a Flax Linear kernel in ``new_state[pt_name]`` in PyTorch layout.

    Flax kernels are ``(in, out)`` and PyTorch Linear weights are ``(out, in)``,
    so the array is transposed before it is stored. When ``i`` is given, the
    leading scan (stacked-layer) dimension is indexed first to select layer ``i``.
    """
    full = _to_f32(flax_t)
    kernel = full if i is None else full[i]  # (in, out)
    # .copy() materializes the transposed (out, in) view as its own buffer.
    new_state[pt_name] = torch.from_numpy(kernel.T.copy())
def copy_vector(new_state, flax_t, pt_name, i=None):
    """Store a 1-D scale / bias or a 0-D scalar in ``new_state[pt_name]`` (no transpose).

    With ``i`` set, the leading scan dimension is indexed first. Indexing a
    per-layer scalar yields a NumPy scalar, and ``np.array`` wraps it back into
    a (C-contiguous, copied) ndarray that ``torch.from_numpy`` accepts.
    """
    values = _to_f32(flax_t)
    picked = values if i is None else values[i]
    new_state[pt_name] = torch.from_numpy(np.array(picked, order="C"))
# ---------------------------------------------------------------------------
# Main conversion
# ---------------------------------------------------------------------------
def _parse_args(argv=None):
    """Parse the CLI flags: checkpoint repo, checkpoint file and output path."""
    import argparse
    p = argparse.ArgumentParser(description=(
        "Convert a Cactus-format Flax checkpoint to a PyTorch state_dict for the "
        "needle_torch port. Defaults to the published Cactus-Compute/needle weights; "
        "pass --ckpt-repo / --ckpt-file to convert a finetuned variant."
    ))
    p.add_argument("--ckpt-repo", default=_HF_REPO_DEFAULT,
                   help=f"HF repo containing the checkpoint (default: {_HF_REPO_DEFAULT})")
    p.add_argument("--ckpt-file", default=_HF_FILE_DEFAULT,
                   help=f"Filename within the repo (default: {_HF_FILE_DEFAULT})")
    p.add_argument("--out", default=str(ART / "needle_torch.pt"),
                   help="Output path for the PyTorch state_dict (default: artifacts/needle_torch.pt)")
    return p.parse_args(argv)


def _copy_attention(new_state, attn, prefix, i):
    """Copy one scanned attention module (q/k/v/out kernels and q/k norm scales) for layer ``i``."""
    # All Linear kernels need the Flax->PyTorch transpose; the norm scales do not.
    for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
        copy_kernel(new_state, attn[proj]["kernel"], f"{prefix}.{proj}.weight", i)
    for norm in ("q_norm", "k_norm"):
        copy_vector(new_state, attn[norm]["scale"], f"{prefix}.{norm}.scale", i)


def _convert_params(flax_params, pt_config):
    """Map the Flax param tree onto NeedleModel state_dict keys and return the new dict.

    Scanned encoder/decoder blocks are unrolled into one entry per layer, using
    ``pt_config.num_encoder_layers`` / ``pt_config.num_decoder_layers``.
    """
    new_state = {}

    # --- Top-level scalars ---
    copy_vector(new_state, flax_params["log_temp"], "log_temp")

    # --- Shared embedding (no transpose -- Flax Embed stores (vocab, d_model)) ---
    # The port's state_dict exposes the shared weight under three keys; all
    # three reference the same tensor object.
    emb_tensor = torch.from_numpy(_to_f32(flax_params["embedding"]["embedding"]).copy())
    for key in ("embedding.weight", "encoder.embedding.weight", "decoder.embedding.weight"):
        new_state[key] = emb_tensor

    # --- Contrastive head ---
    # contrastive_hidden: kernel (d_model, d_model//4), bias (d_model//4,)
    hidden = flax_params["contrastive_hidden"]
    copy_kernel(new_state, hidden["kernel"], "contrastive_hidden.weight")
    copy_vector(new_state, hidden["bias"], "contrastive_hidden.bias")
    # contrastive_proj: kernel (d_model//4, contrastive_dim), no bias
    copy_kernel(new_state, flax_params["contrastive_proj"]["kernel"], "contrastive_proj.weight")

    # --- Encoder ---
    encoder = flax_params["encoder"]
    copy_vector(new_state, encoder["final_norm"]["scale"], "encoder.final_norm.scale")
    # nn.scan: EncoderBlock_0 has leading dim = num_encoder_layers
    enc_block = encoder["layers"]["EncoderBlock_0"]
    for i in range(pt_config.num_encoder_layers):
        base = f"encoder.layers.{i}"
        copy_vector(new_state, enc_block["attn_gate"], f"{base}.attn_gate", i)
        # pre-norm (ZCRMSNorm_0.scale[i] -> layers.i.norm.scale)
        copy_vector(new_state, enc_block["ZCRMSNorm_0"]["scale"], f"{base}.norm.scale", i)
        _copy_attention(new_state, enc_block["self_attn"], f"{base}.self_attn", i)

    # --- Decoder ---
    decoder = flax_params["decoder"]
    # Flax: decoder.ZCRMSNorm_0.scale -> PyTorch: decoder.final_norm.scale
    copy_vector(new_state, decoder["ZCRMSNorm_0"]["scale"], "decoder.final_norm.scale")
    # nn.scan: DecoderBlock_0 has leading dim = num_decoder_layers
    dec_block = decoder["layers"]["DecoderBlock_0"]
    for i in range(pt_config.num_decoder_layers):
        base = f"decoder.layers.{i}"
        copy_vector(new_state, dec_block["self_attn_gate"], f"{base}.self_attn_gate", i)
        copy_vector(new_state, dec_block["cross_attn_gate"], f"{base}.cross_attn_gate", i)
        # ZCRMSNorm_0 = self-attn pre-norm, ZCRMSNorm_1 = cross-attn pre-norm
        copy_vector(new_state, dec_block["ZCRMSNorm_0"]["scale"], f"{base}.self_norm.scale", i)
        copy_vector(new_state, dec_block["ZCRMSNorm_1"]["scale"], f"{base}.cross_norm.scale", i)
        _copy_attention(new_state, dec_block["self_attn"], f"{base}.self_attn", i)
        _copy_attention(new_state, dec_block["cross_attn"], f"{base}.cross_attn", i)

    return new_state


def _check_keys(new_state, target_state):
    """Print a report and exit if the key sets of ``new_state`` and ``target_state`` differ."""
    target_keys = set(target_state)
    new_keys = set(new_state)
    missing = sorted(target_keys - new_keys)
    extra = sorted(new_keys - target_keys)
    if not (missing or extra):
        return
    if missing:
        print("MISSING keys (in model, not in new_state):")
        for k in missing:
            print(f" {k}")
    if extra:
        print("EXTRA keys (in new_state, not in model):")
        for k in extra:
            print(f" {k}")
    sys.exit("state_dict mismatch -- fix the mapping")


def _check_shapes(new_state, target_state):
    """Print a report and exit if any tensor shape differs from the model's.

    Assumes the key sets already match (``_check_keys`` runs first).
    """
    shape_errors = []
    for k, tensor in new_state.items():
        expected = tuple(target_state[k].shape)
        got = tuple(tensor.shape)
        if expected != got:
            shape_errors.append(f" {k}: model expects {expected}, got {got}")
    if shape_errors:
        print("SHAPE MISMATCHES:")
        for e in shape_errors:
            print(e)
        sys.exit("shape mismatch -- fix transpositions")


def _save_outputs(new_state, config_dict, out_path):
    """Save the state_dict to ``out_path`` and the config as a sidecar ``.config.json``."""
    import json
    # --out may point outside artifacts/, the only directory created on import.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(new_state, out_path)
    print(f"Saved -> {out_path}")
    # The config JSON next to the .pt lets export_onnx.py rebuild the model
    # with the right dims for any finetuned variant.
    config_out = out_path.with_suffix(".config.json")
    config_out.write_text(json.dumps(config_dict, indent=2))
    print(f"Saved -> {config_out}")


def main():
    """Command-line entry point: convert a Cactus Flax checkpoint to a PyTorch state_dict.

    Steps: download and unpickle the checkpoint, build a ``NeedleModel`` from
    the checkpoint's own config, map every Flax parameter onto the port's
    state_dict keys, verify keys and shapes, load with ``strict=True``, then
    save the tensors and the config.

    Exits via ``sys.exit(<message>)`` when the mapping does not match the model.
    """
    args = _parse_args()

    # ---- Step 1: download + load Flax checkpoint ----
    data = load_flax_checkpoint(args.ckpt_repo, args.ckpt_file)
    config_dict = data["config"]
    print(f"\nCheckpoint config: {config_dict}\n")
    flax_params = data["params"]

    # ---- Step 2: instantiate PyTorch port with checkpoint config ----
    pt_config = TransformerConfig(**config_dict)
    model = NeedleModel(pt_config)
    model.eval()
    target_state = model.state_dict()

    # ---- Step 3: walk Flax tree and fill new_state ----
    new_state = _convert_params(flax_params, pt_config)

    # ---- Step 4: verify completeness and shapes before loading ----
    _check_keys(new_state, target_state)
    _check_shapes(new_state, target_state)

    # ---- Step 5: load and verify ----
    result = model.load_state_dict(new_state, strict=True)
    # strict=True already raises on mismatch. Unlike the previous `assert`,
    # this explicit check is not stripped under `python -O`.
    if result.missing_keys or result.unexpected_keys:
        sys.exit(f"load_state_dict unexpected result: {result}")
    print(f"\nSuccessfully loaded {len(new_state)} tensors into PyTorch port (strict=True)")
    print(f"Config: {config_dict}")

    # ---- Step 6: save ----
    _save_outputs(new_state, config_dict, Path(args.out))
# Run the conversion only when executed as a script, not when imported
# (e.g. by another export script reusing the helpers above).
if __name__ == "__main__":
    main()