"""Convert Cactus Needle's Flax checkpoint to a PyTorch state_dict.

HF source: Cactus-Compute/needle / needle.pkl

Usage:
    cd export
    uv run python convert_weights.py

Output: export/artifacts/needle_torch.pt
"""

import argparse
import json
import pickle
import sys
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import hf_hub_download

# Make the PyTorch port importable from export/ (must run before the import
# below, which is why that import is not grouped with the others).
sys.path.insert(0, str(Path(__file__).resolve().parent))
from needle_torch import NeedleModel, TransformerConfig

# Directory used both as the HF download target and the default output location.
ART = Path(__file__).resolve().parent / "artifacts"
ART.mkdir(exist_ok=True)

_HF_REPO_DEFAULT = "Cactus-Compute/needle"
_HF_FILE_DEFAULT = "needle.pkl"

# Sub-module names shared by every attention block handled in this script.
_ATTN_PROJECTIONS = ("q_proj", "k_proj", "v_proj", "out_proj")
_ATTN_QK_NORMS = ("q_norm", "k_norm")


def load_flax_checkpoint(repo_id: str = _HF_REPO_DEFAULT,
                         filename: str = _HF_FILE_DEFAULT):
    """Download a Cactus-format checkpoint from HF and return the raw dict.

    The rest of this script expects the unpickled object to be a mapping with
    a ``"config"`` entry (keyword arguments for ``TransformerConfig``) and a
    ``"params"`` entry (the Flax parameter tree), i.e.
    ``{"config": ..., "params": ...}``. Pass a different repo/filename to
    point at a finetuned variant -- ``main`` reads ``data["config"]`` to
    parametrize the PyTorch port, so dimension changes are picked up from the
    checkpoint itself.

    Args:
        repo_id: Hugging Face model repo that holds the checkpoint.
        filename: Name of the pickle file inside that repo.

    Returns:
        Whatever object the pickle file contains (not validated here).
    """
    local_dir = str(ART)
    print(f"Downloading {filename} from {repo_id}...", flush=True)
    path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        local_dir=local_dir,
    )
    print(f"Loaded from {path}", flush=True)
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only point --ckpt-repo / --ckpt-file at repositories you trust.
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data


# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------

def _to_f32(arr):
    """Convert any array-like (JAX, numpy, bfloat16) to a float32 numpy array."""
    return np.asarray(arr).astype(np.float32)


def _select_f32(flax_t, i=None):
    """Return ``flax_t`` as float32, restricted to leading index ``i`` if given.

    The slice is taken *before* the dtype conversion so that only one layer's
    worth of data is converted per call instead of the whole stacked array.
    """
    raw = np.asarray(flax_t)
    if i is not None:
        raw = raw[i]
    return _to_f32(raw)


def copy_kernel(new_state, flax_t, pt_name, i=None):
    """Copy a 2-D Linear kernel with Flax->PyTorch (in,out)->(out,in) transpose.

    Args:
        new_state: Dict being filled; the tensor is stored under ``pt_name``.
        flax_t: Source array. If ``i`` is given it carries a leading scan
            dimension that is indexed away first.
        pt_name: Destination key in ``new_state``.
        i: Optional index into the leading scan dimension.
    """
    arr = _select_f32(flax_t, i)  # (in, out)
    # .copy() materializes the transpose as a fresh C-contiguous buffer.
    new_state[pt_name] = torch.from_numpy(arr.T.copy())  # (out, in)


def copy_vector(new_state, flax_t, pt_name, i=None):
    """Copy a 1-D scale / bias or a 0-D scalar (no transpose).

    Args:
        new_state: Dict being filled; the tensor is stored under ``pt_name``.
        flax_t: Source array, optionally with a leading scan dimension.
        pt_name: Destination key in ``new_state``.
        i: Optional index into the leading scan dimension.
    """
    arr = _select_f32(flax_t, i)
    # np.array(...) makes one owned copy and keeps 0-D inputs 0-D.
    new_state[pt_name] = torch.from_numpy(np.array(arr, order="C"))


def _copy_attention(new_state, flax_attn, pt_prefix, i):
    """Copy one attention block's projection kernels and QK-norm scales.

    Reads ``flax_attn[<proj>]["kernel"]`` for each projection and
    ``flax_attn[<norm>]["scale"]`` for each QK norm, at scan index ``i``, and
    writes them under ``{pt_prefix}.<name>.weight`` / ``.scale``.
    """
    # Projections are Linear kernels and need the (in,out)->(out,in) transpose.
    for proj in _ATTN_PROJECTIONS:
        copy_kernel(new_state, flax_attn[proj]["kernel"],
                    f"{pt_prefix}.{proj}.weight", i)
    # QK norms are scale vectors, no transpose.
    for norm in _ATTN_QK_NORMS:
        copy_vector(new_state, flax_attn[norm]["scale"],
                    f"{pt_prefix}.{norm}.scale", i)


# ---------------------------------------------------------------------------
# Main conversion
# ---------------------------------------------------------------------------

def main():
    """Download the checkpoint, map it onto the PyTorch port, verify, and save.

    Exits via ``sys.exit`` with a diagnostic message when the converted keys
    or shapes do not match the model's ``state_dict``. On success writes the
    state_dict to ``--out`` and the checkpoint config next to it as
    ``<out stem>.config.json``.
    """
    p = argparse.ArgumentParser(description=(
        "Convert a Cactus-format Flax checkpoint to a PyTorch state_dict for the "
        "needle_torch port. Defaults to the published Cactus-Compute/needle weights; "
        "pass --ckpt-repo / --ckpt-file to convert a finetuned variant."
    ))
    p.add_argument("--ckpt-repo", default=_HF_REPO_DEFAULT,
                   help=f"HF repo containing the checkpoint (default: {_HF_REPO_DEFAULT})")
    p.add_argument("--ckpt-file", default=_HF_FILE_DEFAULT,
                   help=f"Filename within the repo (default: {_HF_FILE_DEFAULT})")
    p.add_argument("--out", default=str(ART / "needle_torch.pt"),
                   help="Output path for the PyTorch state_dict (default: artifacts/needle_torch.pt)")
    args = p.parse_args()

    # ---- Step 1: download + load Flax checkpoint ----
    data = load_flax_checkpoint(args.ckpt_repo, args.ckpt_file)
    config_dict = data["config"]
    print(f"\nCheckpoint config: {config_dict}\n")
    flax_params = data["params"]

    # ---- Step 2: instantiate PyTorch port with checkpoint config ----
    pt_config = TransformerConfig(**config_dict)
    model = NeedleModel(pt_config)
    model.eval()
    target_state = model.state_dict()

    # ---- Step 3: walk Flax tree and fill new_state ----
    new_state = {}

    # --- Top-level scalars ---
    copy_vector(new_state, flax_params["log_temp"], "log_temp")

    # --- Shared embedding (no transpose -- Flax Embed stores (vocab, d_model)) ---
    # The state_dict includes the shared weight under three keys:
    #   embedding.weight, encoder.embedding.weight, decoder.embedding.weight
    emb_tensor = torch.from_numpy(_to_f32(flax_params["embedding"]["embedding"]).copy())
    new_state["embedding.weight"] = emb_tensor
    new_state["encoder.embedding.weight"] = emb_tensor
    new_state["decoder.embedding.weight"] = emb_tensor

    # --- Contrastive head ---
    # contrastive_hidden: kernel (d_model, d_model//4), bias (d_model//4,)
    copy_kernel(new_state, flax_params["contrastive_hidden"]["kernel"],
                "contrastive_hidden.weight")
    copy_vector(new_state, flax_params["contrastive_hidden"]["bias"],
                "contrastive_hidden.bias")
    # contrastive_proj: kernel (d_model//4, contrastive_dim), no bias
    copy_kernel(new_state, flax_params["contrastive_proj"]["kernel"],
                "contrastive_proj.weight")

    # --- Encoder final norm ---
    copy_vector(new_state, flax_params["encoder"]["final_norm"]["scale"],
                "encoder.final_norm.scale")

    # --- Encoder layers (nn.scan: EncoderBlock_0 has leading dim = num_encoder_layers) ---
    enc_block = flax_params["encoder"]["layers"]["EncoderBlock_0"]
    for i in range(pt_config.num_encoder_layers):
        base = f"encoder.layers.{i}"

        # attn_gate: scalar at index i
        copy_vector(new_state, enc_block["attn_gate"], f"{base}.attn_gate", i)

        # pre-norm (ZCRMSNorm_0.scale[i] -> layers.i.norm.scale)
        copy_vector(new_state, enc_block["ZCRMSNorm_0"]["scale"],
                    f"{base}.norm.scale", i)

        # self-attention projections + QK norms
        _copy_attention(new_state, enc_block["self_attn"], f"{base}.self_attn", i)

    # --- Decoder final norm ---
    # Flax: decoder.ZCRMSNorm_0.scale -> PyTorch: decoder.final_norm.scale
    copy_vector(new_state, flax_params["decoder"]["ZCRMSNorm_0"]["scale"],
                "decoder.final_norm.scale")

    # --- Decoder layers (nn.scan: DecoderBlock_0 has leading dim = num_decoder_layers) ---
    dec_block = flax_params["decoder"]["layers"]["DecoderBlock_0"]
    for i in range(pt_config.num_decoder_layers):
        base = f"decoder.layers.{i}"

        # Gates
        copy_vector(new_state, dec_block["self_attn_gate"], f"{base}.self_attn_gate", i)
        copy_vector(new_state, dec_block["cross_attn_gate"], f"{base}.cross_attn_gate", i)

        # Pre-norms
        # ZCRMSNorm_0 = self-attn pre-norm -> self_norm
        copy_vector(new_state, dec_block["ZCRMSNorm_0"]["scale"],
                    f"{base}.self_norm.scale", i)
        # ZCRMSNorm_1 = cross-attn pre-norm -> cross_norm
        copy_vector(new_state, dec_block["ZCRMSNorm_1"]["scale"],
                    f"{base}.cross_norm.scale", i)

        # Self-attention, then cross-attention (projections + QK norms each)
        _copy_attention(new_state, dec_block["self_attn"], f"{base}.self_attn", i)
        _copy_attention(new_state, dec_block["cross_attn"], f"{base}.cross_attn", i)

    # ---- Step 4: verify completeness before loading ----
    missing = sorted(set(target_state.keys()) - set(new_state.keys()))
    extra = sorted(set(new_state.keys()) - set(target_state.keys()))
    if missing or extra:
        print("MISSING keys (in model, not in new_state):")
        for k in missing:
            print(f"  {k}")
        print("EXTRA keys (in new_state, not in model):")
        for k in extra:
            print(f"  {k}")
        sys.exit("state_dict mismatch -- fix the mapping")

    # Shape check before load_state_dict (key sets are known to match here)
    shape_errors = []
    for k in new_state:
        expected = tuple(target_state[k].shape)
        got = tuple(new_state[k].shape)
        if expected != got:
            shape_errors.append(f"  {k}: model expects {expected}, got {got}")
    if shape_errors:
        print("SHAPE MISMATCHES:")
        for e in shape_errors:
            print(e)
        sys.exit("shape mismatch -- fix transpositions")

    # ---- Step 5: load and verify ----
    result = model.load_state_dict(new_state, strict=True)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if result.missing_keys or result.unexpected_keys:
        sys.exit(f"load_state_dict unexpected result: {result}")

    n = len(new_state)
    print(f"\nSuccessfully loaded {n} tensors into PyTorch port (strict=True)")
    print(f"Config: {config_dict}")

    # ---- Step 6: save ----
    out_path = Path(args.out)
    # A custom --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(new_state, out_path)
    print(f"Saved -> {out_path}")

    # Also save the config as JSON next to the .pt so export_onnx.py can rebuild
    # the model with the right dims for any finetuned variant.
    config_out = out_path.with_suffix(".config.json")
    config_out.write_text(json.dumps(config_dict, indent=2), encoding="utf-8")
    print(f"Saved -> {config_out}")


if __name__ == "__main__":
    main()