| """Convert Cactus Needle's Flax checkpoint to a PyTorch state_dict. |
| |
| HF source: Cactus-Compute/needle / needle.pkl |
| |
| Usage: |
| cd export |
| uv run python convert_weights.py |
| |
| Output: export/artifacts/needle_torch.pt |
| """ |
|
|
| import pickle |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) |
| from needle_torch import NeedleModel, TransformerConfig |
|
|
| ART = Path(__file__).resolve().parent / "artifacts" |
| ART.mkdir(exist_ok=True) |
|
|
| _HF_REPO_DEFAULT = "Cactus-Compute/needle" |
| _HF_FILE_DEFAULT = "needle.pkl" |
|
|
|
|
def load_flax_checkpoint(repo_id: str = _HF_REPO_DEFAULT, filename: str = _HF_FILE_DEFAULT):
    """Download a Cactus-format checkpoint from the Hugging Face Hub and unpickle it.

    Works for any model trained with Cactus's pipeline because the training code
    always saves `{"config": <dict>, "params": <pytree>}` in the same shape.
    Pass a different repo/filename to point at a finetuned variant. The rest of
    this script reads `data["config"]` to parametrize the PyTorch port, so
    dimension changes (d_model, layer counts, GQA ratios) are picked up
    automatically.

    Args:
        repo_id: Hugging Face model repo holding the checkpoint.
        filename: Pickle file inside that repo.

    Returns:
        The unpickled mapping, guaranteed to contain "config" and "params".

    Raises:
        ValueError: if the unpickled object is not a mapping with both
            "config" and "params" keys.
    """
    from collections.abc import Mapping

    print(f"Downloading {filename} from {repo_id}...", flush=True)
    path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        local_dir=str(ART),  # keep the download beside the converted outputs
    )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # repo_id / filename are user-selectable (--ckpt-repo / --ckpt-file), so
    # only point this at repositories you trust.
    with open(path, "rb") as f:
        data = pickle.load(f)
    print(f"Loaded from {path}", flush=True)

    # Fail early with a clear message instead of a bare KeyError later in main().
    if not isinstance(data, Mapping):
        raise ValueError(
            f"{filename} from {repo_id} is not a Cactus checkpoint: expected a "
            f"mapping with 'config' and 'params', got {type(data).__name__}"
        )
    missing = [key for key in ("config", "params") if key not in data]
    if missing:
        raise ValueError(
            f"{filename} from {repo_id} is not a Cactus checkpoint: missing {missing}"
        )
    return data
|
|
|
|
| |
| |
| |
|
|
| def _to_f32(arr): |
| """Convert any array-like (JAX, numpy, bfloat16) to a float32 numpy array.""" |
| return np.asarray(arr).astype(np.float32) |
|
|
|
|
def copy_kernel(new_state, flax_t, pt_name, i=None):
    """Copy a 2-D Linear kernel with the Flax->PyTorch (in,out)->(out,in) transpose.

    Args:
        new_state: dict being filled; ``new_state[pt_name]`` is (over)written
            with a new float32 tensor that does not share memory with *flax_t*.
        flax_t: array-like kernel (NumPy, JAX, bfloat16, ...). When *i* is given
            it is expected to carry a leading scan (layer) dimension.
        pt_name: destination state_dict key.
        i: optional index into the leading scan dimension.
    """
    arr = np.asarray(flax_t)
    if i is not None:
        # Slice BEFORE casting: casting the whole stacked array on every
        # per-layer call made the conversion cost quadratic in layer count.
        arr = arr[i]
    # copy=False skips the cast copy when already float32; .T is a view and the
    # final .copy() yields a C-contiguous array independent of the checkpoint.
    new_state[pt_name] = torch.from_numpy(arr.astype(np.float32, copy=False).T.copy())
|
|
|
|
def copy_vector(new_state, flax_t, pt_name, i=None):
    """Copy a 1-D scale / bias or a 0-D scalar (no transpose).

    Args:
        new_state: dict being filled; ``new_state[pt_name]`` is (over)written
            with a new float32 tensor that does not share memory with *flax_t*.
        flax_t: array-like value (NumPy, JAX, bfloat16, ...). When *i* is given
            it is expected to carry a leading scan (layer) dimension.
        pt_name: destination state_dict key.
        i: optional index into the leading scan dimension.
    """
    arr = np.asarray(flax_t)
    if i is not None:
        # Slice BEFORE casting (avoids re-casting the whole stack per layer).
        # Indexing a 1-D stack yields a NumPy scalar; re-wrap it as a 0-D array
        # because torch.from_numpy needs an ndarray.
        arr = np.asarray(arr[i])
    # astype allocates by default, so the tensor never aliases the checkpoint.
    new_state[pt_name] = torch.from_numpy(arr.astype(np.float32, order="C"))
|
|
|
|
| |
| |
| |
|
|
def _parse_args():
    """Parse the command-line options for ``main``."""
    import argparse

    p = argparse.ArgumentParser(description=(
        "Convert a Cactus-format Flax checkpoint to a PyTorch state_dict for the "
        "needle_torch port. Defaults to the published Cactus-Compute/needle weights; "
        "pass --ckpt-repo / --ckpt-file to convert a finetuned variant."
    ))
    p.add_argument("--ckpt-repo", default=_HF_REPO_DEFAULT,
                   help=f"HF repo containing the checkpoint (default: {_HF_REPO_DEFAULT})")
    p.add_argument("--ckpt-file", default=_HF_FILE_DEFAULT,
                   help=f"Filename within the repo (default: {_HF_FILE_DEFAULT})")
    p.add_argument("--out", default=str(ART / "needle_torch.pt"),
                   help="Output path for the PyTorch state_dict (default: artifacts/needle_torch.pt)")
    return p.parse_args()


def _copy_attention(new_state, attn, prefix, i):
    """Copy one layer's attention weights from a scanned Flax subtree.

    Args:
        new_state: dict being filled with PyTorch tensors.
        attn: Flax params of one attention module. It holds q_proj/k_proj/
            v_proj/out_proj kernels and q_norm/k_norm scales, each indexed
            by layer via *i*.
        prefix: PyTorch key prefix, e.g. "encoder.layers.0.self_attn".
        i: layer index into the leading scan dimension.
    """
    for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
        copy_kernel(new_state, attn[proj]["kernel"], f"{prefix}.{proj}.weight", i)
    for norm in ("q_norm", "k_norm"):
        copy_vector(new_state, attn[norm]["scale"], f"{prefix}.{norm}.scale", i)


def _build_state_dict(flax_params, pt_config):
    """Map the Flax parameter tree onto NeedleModel state_dict keys.

    Per-layer parameters live under a single "EncoderBlock_0" /
    "DecoderBlock_0" entry and are sliced by layer index. This assumes a
    scan-stacked layout; the key/shape checks in main catch a mismatch.

    Returns:
        dict mapping PyTorch state_dict keys to float32 tensors.
    """
    new_state = {}

    # Top-level scalar, copied without transpose.
    copy_vector(new_state, flax_params["log_temp"], "log_temp")

    # One embedding table: the same tensor object is stored under all three keys.
    emb_tensor = torch.from_numpy(_to_f32(flax_params["embedding"]["embedding"]).copy())
    new_state["embedding.weight"] = emb_tensor
    new_state["encoder.embedding.weight"] = emb_tensor
    new_state["decoder.embedding.weight"] = emb_tensor

    # Contrastive head (contrastive_proj has no bias in the checkpoint).
    copy_kernel(new_state, flax_params["contrastive_hidden"]["kernel"], "contrastive_hidden.weight")
    copy_vector(new_state, flax_params["contrastive_hidden"]["bias"], "contrastive_hidden.bias")
    copy_kernel(new_state, flax_params["contrastive_proj"]["kernel"], "contrastive_proj.weight")

    # Encoder: final norm, then per-layer gate, norm and self-attention.
    copy_vector(new_state, flax_params["encoder"]["final_norm"]["scale"], "encoder.final_norm.scale")
    enc_block = flax_params["encoder"]["layers"]["EncoderBlock_0"]
    for i in range(pt_config.num_encoder_layers):
        base = f"encoder.layers.{i}"
        copy_vector(new_state, enc_block["attn_gate"], f"{base}.attn_gate", i)
        copy_vector(new_state, enc_block["ZCRMSNorm_0"]["scale"], f"{base}.norm.scale", i)
        _copy_attention(new_state, enc_block["self_attn"], f"{base}.self_attn", i)

    # Decoder: the decoder-level ZCRMSNorm_0 is its final norm; inside each block
    # ZCRMSNorm_0 / ZCRMSNorm_1 are the self- / cross-attention norms.
    copy_vector(new_state, flax_params["decoder"]["ZCRMSNorm_0"]["scale"], "decoder.final_norm.scale")
    dec_block = flax_params["decoder"]["layers"]["DecoderBlock_0"]
    for i in range(pt_config.num_decoder_layers):
        base = f"decoder.layers.{i}"
        copy_vector(new_state, dec_block["self_attn_gate"], f"{base}.self_attn_gate", i)
        copy_vector(new_state, dec_block["cross_attn_gate"], f"{base}.cross_attn_gate", i)
        copy_vector(new_state, dec_block["ZCRMSNorm_0"]["scale"], f"{base}.self_norm.scale", i)
        copy_vector(new_state, dec_block["ZCRMSNorm_1"]["scale"], f"{base}.cross_norm.scale", i)
        _copy_attention(new_state, dec_block["self_attn"], f"{base}.self_attn", i)
        _copy_attention(new_state, dec_block["cross_attn"], f"{base}.cross_attn", i)

    return new_state


def _verify_state_dict(new_state, target_state):
    """Exit with a report if *new_state* differs from *target_state* in keys or shapes."""
    missing = sorted(set(target_state) - set(new_state))
    extra = sorted(set(new_state) - set(target_state))
    if missing or extra:
        print("MISSING keys (in model, not in new_state):")
        for k in missing:
            print(f" {k}")
        print("EXTRA keys (in new_state, not in model):")
        for k in extra:
            print(f" {k}")
        sys.exit("state_dict mismatch -- fix the mapping")

    # Keys now match exactly, so every new_state key indexes target_state.
    shape_errors = []
    for k, tensor in new_state.items():
        expected = tuple(target_state[k].shape)
        got = tuple(tensor.shape)
        if expected != got:
            shape_errors.append(f" {k}: model expects {expected}, got {got}")
    if shape_errors:
        print("SHAPE MISMATCHES:")
        for e in shape_errors:
            print(e)
        sys.exit("shape mismatch -- fix transpositions")


def main():
    """Convert a Cactus Flax checkpoint into a NeedleModel state_dict on disk.

    Steps:
      1. Download and unpickle the checkpoint.
      2. Build NeedleModel from the checkpoint's config.
      3. Map every Flax parameter to its PyTorch key.
      4. Verify keys and shapes, then load strictly into the model.
      5. Save the state_dict to ``--out``, and the config as JSON next to it
         with the suffix replaced by ``.config.json``.
    """
    import json

    args = _parse_args()

    data = load_flax_checkpoint(args.ckpt_repo, args.ckpt_file)
    config_dict = data["config"]
    print(f"\nCheckpoint config: {config_dict}\n")
    flax_params = data["params"]

    pt_config = TransformerConfig(**config_dict)
    model = NeedleModel(pt_config)
    model.eval()
    target_state = model.state_dict()

    new_state = _build_state_dict(flax_params, pt_config)
    _verify_state_dict(new_state, target_state)

    # strict=True already raises on key mismatch; this explicit check (unlike the
    # former assert) survives `python -O`.
    result = model.load_state_dict(new_state, strict=True)
    if result.missing_keys or result.unexpected_keys:
        sys.exit(f"load_state_dict unexpected result: {result}")

    n = len(new_state)
    print(f"\nSuccessfully loaded {n} tensors into PyTorch port (strict=True)")
    print(f"Config: {config_dict}")

    out_path = Path(args.out)
    # --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(new_state, out_path)
    print(f"Saved -> {out_path}")

    config_out = out_path.with_suffix(".config.json")
    config_out.write_text(json.dumps(config_dict, indent=2), encoding="utf-8")
    print(f"Saved -> {config_out}")
|
|
|
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not on import.
    main()
|
|