shreyask committed on
Commit
03f1e75
·
verified ·
1 Parent(s): cd8d04f

Upload convert_weights.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. convert_weights.py +243 -0
convert_weights.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert Cactus Needle's Flax checkpoint to a PyTorch state_dict.
2
+
3
+ HF source: Cactus-Compute/needle / needle.pkl
4
+
5
+ Usage:
6
+ cd export
7
+ uv run python convert_weights.py
8
+
9
+ Output: export/artifacts/needle_torch.pt
10
+ """
11
+
12
+ import pickle
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+ import torch
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ # Make the PyTorch port importable from export/
21
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
22
+ from needle_torch import NeedleModel, TransformerConfig
23
+
24
# Directory next to this script that receives the downloaded checkpoint and,
# by default, the converted outputs. Created at import time.
ART = Path(__file__).resolve().parent / "artifacts"
ART.mkdir(exist_ok=True)

# Default Hugging Face Hub repo and filename of the checkpoint to convert.
_HF_REPO_DEFAULT = "Cactus-Compute/needle"
_HF_FILE_DEFAULT = "needle.pkl"
29
+
30
+
31
def load_flax_checkpoint(repo_id: str = _HF_REPO_DEFAULT, filename: str = _HF_FILE_DEFAULT):
    """Fetch a checkpoint file from the Hugging Face Hub and unpickle it.

    The file is downloaded into the ``ART`` directory and its unpickled
    content is returned unchanged. The rest of this script expects that
    object to be a mapping with a ``"config"`` entry (keyword arguments for
    the PyTorch config) and a ``"params"`` entry (the Flax parameter tree),
    so pointing ``repo_id`` / ``filename`` at a finetuned variant saved in
    the same layout lets the model dimensions follow the checkpoint.

    Args:
        repo_id: Hub model repository that hosts the checkpoint.
        filename: Name of the pickle file inside that repository.

    Returns:
        Whatever object the pickle file contains.
    """
    print(f"Downloading {filename} from {repo_id}...", flush=True)
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        local_dir=str(ART),
    )
    print(f"Loaded from {ckpt_path}", flush=True)

    # NOTE(security): unpickling can execute arbitrary code embedded in the
    # file. Only point this at repositories you trust.
    with open(ckpt_path, "rb") as handle:
        return pickle.load(handle)
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Conversion helpers
56
+ # ---------------------------------------------------------------------------
57
+
58
def _to_f32(arr):
    """Return ``arr`` as a float32 NumPy array.

    The input is first viewed through ``np.asarray`` (so lists, NumPy arrays
    and other array-likes are accepted) and then cast with ``astype``, which
    produces a new array rather than a view of the input.
    """
    as_numpy = np.asarray(arr)
    return as_numpy.astype(np.float32)
61
+
62
+
63
def copy_kernel(new_state, flax_t, pt_name, i=None):
    """Store a Linear kernel in ``new_state`` using PyTorch's (out, in) layout.

    Flax keeps Dense kernels as (in, out); PyTorch ``Linear.weight`` is
    (out, in), so the kernel is transposed on the way in.

    Args:
        new_state: Mapping being filled; receives a float32 tensor under
            ``pt_name``.
        flax_t: Array-like kernel, (in, out), or stacked along a leading
            scan axis when ``i`` is given.
        pt_name: Destination key in ``new_state``.
        i: Optional index into the leading axis, applied before the
            transpose.
    """
    arr = np.asarray(flax_t)
    if i is not None:
        # Slice before casting so only the selected layer is converted,
        # instead of casting the whole stacked array on every call.
        arr = arr[i]  # (in, out)
    # One copy does the transpose, the float32 cast and the C-contiguous
    # layout; the result is a fresh array owned by the tensor.
    arr = arr.T.astype(np.float32, order="C")  # (out, in)
    new_state[pt_name] = torch.from_numpy(arr)
73
+
74
+
75
def copy_vector(new_state, flax_t, pt_name, i=None):
    """Store a scale / bias vector or a scalar in ``new_state`` untransposed.

    Args:
        new_state: Mapping being filled; receives a float32 tensor under
            ``pt_name``.
        flax_t: Array-like value, optionally stacked along a leading axis
            when ``i`` is given.
        pt_name: Destination key in ``new_state``.
        i: Optional index into the leading axis. Indexing a 1-D input yields
            a 0-d tensor.
    """
    arr = np.asarray(flax_t)
    if i is not None:
        # Slice before casting so only the selected entry is converted.
        arr = arr[i]
    # Indexing a 1-D array returns a NumPy scalar; re-wrap it so the result is
    # always an ndarray (0-d for scalars), then make one float32 copy.
    arr = np.asarray(arr).astype(np.float32, order="C")
    new_state[pt_name] = torch.from_numpy(arr)
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Main conversion
85
+ # ---------------------------------------------------------------------------
86
+
87
def _parse_args():
    """Build the command-line parser and parse the process arguments."""
    import argparse

    p = argparse.ArgumentParser(description=(
        "Convert a Cactus-format Flax checkpoint to a PyTorch state_dict for the "
        "needle_torch port. Defaults to the published Cactus-Compute/needle weights; "
        "pass --ckpt-repo / --ckpt-file to convert a finetuned variant."
    ))
    p.add_argument("--ckpt-repo", default=_HF_REPO_DEFAULT,
                   help=f"HF repo containing the checkpoint (default: {_HF_REPO_DEFAULT})")
    p.add_argument("--ckpt-file", default=_HF_FILE_DEFAULT,
                   help=f"Filename within the repo (default: {_HF_FILE_DEFAULT})")
    p.add_argument("--out", default=str(ART / "needle_torch.pt"),
                   help="Output path for the PyTorch state_dict (default: artifacts/needle_torch.pt)")
    return p.parse_args()


def _copy_attention(new_state, flax_attn, pt_prefix, i):
    """Copy one attention module's four projections and two QK norms for layer ``i``."""
    # Projections are Linear kernels and need the (in, out) -> (out, in) transpose.
    for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
        copy_kernel(new_state, flax_attn[proj]["kernel"], f"{pt_prefix}.{proj}.weight", i)
    # QK norms are scale vectors, copied without a transpose.
    for norm in ("q_norm", "k_norm"):
        copy_vector(new_state, flax_attn[norm]["scale"], f"{pt_prefix}.{norm}.scale", i)


def _check_state(new_state, target_state):
    """Exit with a diagnostic if keys or shapes differ from the model's state_dict."""
    missing = sorted(set(target_state) - set(new_state))
    extra = sorted(set(new_state) - set(target_state))
    if missing or extra:
        print("MISSING keys (in model, not in new_state):")
        for k in missing:
            print(f" {k}")
        print("EXTRA keys (in new_state, not in model):")
        for k in extra:
            print(f" {k}")
        sys.exit("state_dict mismatch -- fix the mapping")

    # Key sets are equal here, so every key can be looked up in target_state.
    shape_errors = []
    for k, tensor in new_state.items():
        expected = tuple(target_state[k].shape)
        got = tuple(tensor.shape)
        if expected != got:
            shape_errors.append(f" {k}: model expects {expected}, got {got}")
    if shape_errors:
        print("SHAPE MISMATCHES:")
        for e in shape_errors:
            print(e)
        sys.exit("shape mismatch -- fix transpositions")


def main():
    """Convert the Flax checkpoint to a PyTorch state_dict and save it.

    Steps: download and unpickle the checkpoint, build the PyTorch model from
    the checkpoint's config, map every Flax parameter to its PyTorch key,
    verify keys and shapes against the model, load with ``strict=True``, then
    write the state_dict to ``--out`` and the config as JSON next to it.

    Exits via ``sys.exit`` with a message when the mapping does not match the
    model's state_dict.
    """
    import json

    args = _parse_args()

    # ---- Step 1: download + load Flax checkpoint ----
    data = load_flax_checkpoint(args.ckpt_repo, args.ckpt_file)

    config_dict = data["config"]
    print(f"\nCheckpoint config: {config_dict}\n")

    flax_params = data["params"]

    # ---- Step 2: instantiate PyTorch port with checkpoint config ----
    pt_config = TransformerConfig(**config_dict)
    model = NeedleModel(pt_config)
    model.eval()

    target_state = model.state_dict()

    # ---- Step 3: walk Flax tree and fill new_state ----
    new_state = {}

    # --- Top-level scalars ---
    copy_vector(new_state, flax_params["log_temp"], "log_temp")

    # --- Shared embedding (no transpose -- Flax Embed stores (vocab, d_model)) ---
    # The state_dict includes the shared weight under three keys:
    # embedding.weight, encoder.embedding.weight, decoder.embedding.weight
    emb_tensor = torch.from_numpy(_to_f32(flax_params["embedding"]["embedding"]).copy())
    new_state["embedding.weight"] = emb_tensor
    new_state["encoder.embedding.weight"] = emb_tensor
    new_state["decoder.embedding.weight"] = emb_tensor

    # --- Contrastive head ---
    # contrastive_hidden: kernel (d_model, d_model//4), bias (d_model//4,)
    copy_kernel(new_state, flax_params["contrastive_hidden"]["kernel"], "contrastive_hidden.weight")
    copy_vector(new_state, flax_params["contrastive_hidden"]["bias"], "contrastive_hidden.bias")

    # contrastive_proj: kernel (d_model//4, contrastive_dim), no bias
    copy_kernel(new_state, flax_params["contrastive_proj"]["kernel"], "contrastive_proj.weight")

    # --- Encoder final norm ---
    copy_vector(new_state, flax_params["encoder"]["final_norm"]["scale"], "encoder.final_norm.scale")

    # --- Encoder layers (nn.scan: EncoderBlock_0 has leading dim = num_encoder_layers) ---
    enc_block = flax_params["encoder"]["layers"]["EncoderBlock_0"]
    for i in range(pt_config.num_encoder_layers):
        base = f"encoder.layers.{i}"

        # attn_gate: scalar at index i
        copy_vector(new_state, enc_block["attn_gate"], f"{base}.attn_gate", i)

        # pre-norm (ZCRMSNorm_0.scale[i] -> layers.i.norm.scale)
        copy_vector(new_state, enc_block["ZCRMSNorm_0"]["scale"], f"{base}.norm.scale", i)

        # self-attention projections and QK norms
        _copy_attention(new_state, enc_block["self_attn"], f"{base}.self_attn", i)

    # --- Decoder final norm ---
    # Flax: decoder.ZCRMSNorm_0.scale -> PyTorch: decoder.final_norm.scale
    copy_vector(new_state, flax_params["decoder"]["ZCRMSNorm_0"]["scale"], "decoder.final_norm.scale")

    # --- Decoder layers (nn.scan: DecoderBlock_0 has leading dim = num_decoder_layers) ---
    dec_block = flax_params["decoder"]["layers"]["DecoderBlock_0"]
    for i in range(pt_config.num_decoder_layers):
        base = f"decoder.layers.{i}"

        # Gates
        copy_vector(new_state, dec_block["self_attn_gate"], f"{base}.self_attn_gate", i)
        copy_vector(new_state, dec_block["cross_attn_gate"], f"{base}.cross_attn_gate", i)

        # Pre-norms
        # ZCRMSNorm_0 = self-attn pre-norm -> self_norm
        copy_vector(new_state, dec_block["ZCRMSNorm_0"]["scale"], f"{base}.self_norm.scale", i)
        # ZCRMSNorm_1 = cross-attn pre-norm -> cross_norm
        copy_vector(new_state, dec_block["ZCRMSNorm_1"]["scale"], f"{base}.cross_norm.scale", i)

        # Self-attention, then cross-attention (projections + QK norms each)
        _copy_attention(new_state, dec_block["self_attn"], f"{base}.self_attn", i)
        _copy_attention(new_state, dec_block["cross_attn"], f"{base}.cross_attn", i)

    # ---- Step 4: verify completeness and shapes before loading ----
    _check_state(new_state, target_state)

    # ---- Step 5: load and verify ----
    result = model.load_state_dict(new_state, strict=True)
    # Explicit check instead of assert so it still runs under ``python -O``.
    if result.missing_keys or result.unexpected_keys:
        sys.exit(f"load_state_dict unexpected result: {result}")

    n = len(new_state)
    print(f"\nSuccessfully loaded {n} tensors into PyTorch port (strict=True)")
    print(f"Config: {config_dict}")

    # ---- Step 6: save ----
    out_path = Path(args.out)
    # A custom --out may point into a directory that does not exist yet;
    # create it so the save does not fail after all the conversion work.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(new_state, out_path)
    print(f"Saved -> {out_path}")

    # Also save the config as JSON next to the .pt so export_onnx.py can rebuild
    # the model with the right dims for any finetuned variant.
    config_out = out_path.with_suffix(".config.json")
    config_out.write_text(json.dumps(config_dict, indent=2))
    print(f"Saved -> {config_out}")
240
+
241
+
242
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()